wangsen / paddle_dbnet / Commits / 8bae1e40

Commit 8bae1e40 (unverified), authored Jan 14, 2022 by MissPenguin, committed by GitHub on Jan 14, 2022.

Merge pull request #5174 from WenmuZhou/fix_vqa

    vqa code integrated into ppocr training system

Parents: 9fa209e3, 1cbe4bf2
Changes: 66 files in this commit. This page shows 20 changed files with 687 additions and 25 deletions (+687 -25); the remaining files are on the later pages of the diff.
ppocr/data/imaug/vqa/token/vqa_token_pad.py (+104 -0)
ppocr/data/imaug/vqa/token/vqa_token_relation.py (+67 -0)
ppocr/data/lmdb_dataset.py (+3 -0)
ppocr/data/pgnet_dataset.py (+2 -0)
ppocr/data/pubtab_dataset.py (+11 -4)
ppocr/data/simple_dataset.py (+2 -1)
ppocr/losses/__init__.py (+8 -1)
ppocr/losses/basic_loss.py (+15 -0)
ppocr/losses/vqa_token_layoutlm_loss.py (+14 -7)
ppocr/metrics/__init__.py (+4 -1)
ppocr/metrics/vqa_token_re_metric.py (+176 -0)
ppocr/metrics/vqa_token_ser_metric.py (+47 -0)
ppocr/modeling/architectures/base_model.py (+8 -3)
ppocr/modeling/backbones/__init__.py (+3 -0)
ppocr/modeling/backbones/vqa_layoutlm.py (+125 -0)
ppocr/optimizer/__init__.py (+3 -1)
ppocr/optimizer/optimizer.py (+35 -0)
ppocr/optimizer/regularizer.py (+5 -6)
ppocr/postprocess/__init__.py (+4 -1)
ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py (+51 -0)
ppocr/data/imaug/vqa/token/vqa_token_pad.py (new file, 0 → 100644)
```python
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np


class VQATokenPad(object):
    def __init__(self,
                 max_seq_len=512,
                 pad_to_max_seq_len=True,
                 return_attention_mask=True,
                 return_token_type_ids=True,
                 truncation_strategy="longest_first",
                 return_overflowing_tokens=False,
                 return_special_tokens_mask=False,
                 infer_mode=False,
                 **kwargs):
        self.max_seq_len = max_seq_len
        # NOTE: the commit assigns max_seq_len here rather than the
        # pad_to_max_seq_len argument; the attribute is only ever used as a
        # truthy flag in __call__, so behavior is unchanged for non-zero
        # sequence lengths.
        self.pad_to_max_seq_len = max_seq_len
        self.return_attention_mask = return_attention_mask
        self.return_token_type_ids = return_token_type_ids
        self.truncation_strategy = truncation_strategy
        self.return_overflowing_tokens = return_overflowing_tokens
        self.return_special_tokens_mask = return_special_tokens_mask
        # label id ignored by CrossEntropyLoss, used to pad the labels
        self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
        self.infer_mode = infer_mode

    def __call__(self, data):
        needs_to_be_padded = (self.pad_to_max_seq_len and
                              len(data["input_ids"]) < self.max_seq_len)

        if needs_to_be_padded:
            if 'tokenizer_params' in data:
                tokenizer_params = data.pop('tokenizer_params')
            else:
                tokenizer_params = dict(
                    padding_side='right', pad_token_type_id=0, pad_token_id=1)

            difference = self.max_seq_len - len(data["input_ids"])
            if tokenizer_params['padding_side'] == 'right':
                if self.return_attention_mask:
                    data["attention_mask"] = [1] * len(data[
                        "input_ids"]) + [0] * difference
                if self.return_token_type_ids:
                    data["token_type_ids"] = (
                        data["token_type_ids"] +
                        [tokenizer_params['pad_token_type_id']] * difference)
                if self.return_special_tokens_mask:
                    data["special_tokens_mask"] = data[
                        "special_tokens_mask"] + [1] * difference
                data["input_ids"] = data["input_ids"] + [
                    tokenizer_params['pad_token_id']
                ] * difference
                if not self.infer_mode:
                    data["labels"] = data[
                        "labels"] + [self.pad_token_label_id] * difference
                data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference
            elif tokenizer_params['padding_side'] == 'left':
                if self.return_attention_mask:
                    data["attention_mask"] = [0] * difference + [1] * len(
                        data["input_ids"])
                if self.return_token_type_ids:
                    data["token_type_ids"] = (
                        [tokenizer_params['pad_token_type_id']] * difference +
                        data["token_type_ids"])
                if self.return_special_tokens_mask:
                    data["special_tokens_mask"] = [
                        1
                    ] * difference + data["special_tokens_mask"]
                data["input_ids"] = [
                    tokenizer_params['pad_token_id']
                ] * difference + data["input_ids"]
                if not self.infer_mode:
                    data["labels"] = [
                        self.pad_token_label_id
                    ] * difference + data["labels"]
                data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"]
        else:
            if self.return_attention_mask:
                data["attention_mask"] = [1] * len(data["input_ids"])

        for key in data:
            if key in [
                    'input_ids', 'labels', 'token_type_ids', 'bbox',
                    'attention_mask'
            ]:
                if self.infer_mode:
                    if key != 'labels':
                        length = min(len(data[key]), self.max_seq_len)
                        data[key] = data[key][:length]
                    else:
                        continue
                data[key] = np.array(data[key], dtype='int64')
        return data
```
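To make the transform's behavior concrete, here is a small usage sketch (not part of the commit) that right-pads a toy sample with the class above; the field values are invented, and running it requires only paddle and numpy:

```python
# Toy sample in the shape VQATokenPad expects: 3 real tokens, padded to 8.
data = {
    "input_ids": [101, 2054, 102],
    "token_type_ids": [0, 0, 0],
    "bbox": [[10, 10, 50, 20]] * 3,
    "labels": [1, 2, 1],
}

pad = VQATokenPad(max_seq_len=8)
out = pad(data)

print(out["input_ids"])       # [101 2054 102 1 1 1 1 1]   (pad_token_id = 1)
print(out["attention_mask"])  # [1 1 1 0 0 0 0 0]
print(out["labels"])          # [1 2 1 -100 -100 -100 -100 -100]  (ignore_index)
print(out["bbox"][-1])        # [0 0 0 0]
```

All five listed keys come back as int64 numpy arrays, ready for batching.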
ppocr/data/imaug/vqa/token/vqa_token_relation.py (new file, 0 → 100644)
```python
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class VQAReTokenRelation(object):
    def __init__(self, **kwargs):
        pass

    def __call__(self, data):
        """
        build relations
        """
        entities = data['entities']
        relations = data['relations']
        id2label = data.pop('id2label')
        empty_entity = data.pop('empty_entity')
        entity_id_to_index_map = data.pop('entity_id_to_index_map')

        # drop duplicates and relations touching empty entities
        relations = list(set(relations))
        relations = [
            rel for rel in relations
            if rel[0] not in empty_entity and rel[1] not in empty_entity
        ]
        kv_relations = []
        for rel in relations:
            pair = [id2label[rel[0]], id2label[rel[1]]]
            if pair == ["question", "answer"]:
                kv_relations.append({
                    "head": entity_id_to_index_map[rel[0]],
                    "tail": entity_id_to_index_map[rel[1]]
                })
            elif pair == ["answer", "question"]:
                # flip answer-first pairs so head is always the question
                kv_relations.append({
                    "head": entity_id_to_index_map[rel[1]],
                    "tail": entity_id_to_index_map[rel[0]]
                })
            else:
                continue
        relations = sorted(
            [{
                "head": rel["head"],
                "tail": rel["tail"],
                "start_index": self.get_relation_span(rel, entities)[0],
                "end_index": self.get_relation_span(rel, entities)[1],
            } for rel in kv_relations],
            key=lambda x: x["head"], )
        data['relations'] = relations
        return data

    def get_relation_span(self, rel, entities):
        bound = []
        for entity_index in [rel["head"], rel["tail"]]:
            bound.append(entities[entity_index]["start"])
            bound.append(entities[entity_index]["end"])
        return min(bound), max(bound)
```
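A worked toy example (illustrative data, not from the commit): one question entity and one answer entity linked answer-first, which `__call__` normalizes into a head=question, tail=answer relation spanning both entities:

```python
# Entity id 0 is a question, id 1 is its answer; the raw relation is given
# answer-first, so __call__ flips it.
data = {
    'entities': [{'start': 0, 'end': 4}, {'start': 5, 'end': 9}],
    'relations': [(1, 0)],                 # (answer_id, question_id)
    'id2label': {0: 'question', 1: 'answer'},
    'empty_entity': set(),
    'entity_id_to_index_map': {0: 0, 1: 1},
}

out = VQAReTokenRelation()(data)
print(out['relations'])
# [{'head': 0, 'tail': 1, 'start_index': 0, 'end_index': 9}]
```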
ppocr/data/lmdb_dataset.py
```diff
@@ -38,6 +38,9 @@ class LMDBDataSet(Dataset):
             np.random.shuffle(self.data_idx_order_list)
         self.ops = create_operators(dataset_config['transforms'],
                                     global_config)
+        ratio_list = dataset_config.get("ratio_list", [1.0])
+        self.need_reset = True in [x < 1 for x in ratio_list]
+
     def load_hierarchical_lmdb_dataset(self, data_dir):
         lmdb_sets = {}
         dataset_idx = 0
```
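The `True in [x < 1 for x in ratio_list]` expression, repeated in each dataset below, just checks whether any data source is subsampled (ratio below 1), in which case the dataset must be resampled and reset between epochs; it is equivalent to `any()`:

```python
ratio_list = [1.0, 0.5]
need_reset = True in [x < 1 for x in ratio_list]   # True: one source is subsampled
assert need_reset == any(x < 1 for x in ratio_list)
```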
ppocr/data/pgnet_dataset.py
```diff
@@ -49,6 +49,8 @@ class PGDataSet(Dataset):
         self.ops = create_operators(dataset_config['transforms'],
                                     global_config)
+        self.need_reset = True in [x < 1 for x in ratio_list]
+
     def shuffle_data_random(self):
         if self.do_shuffle:
             random.seed(self.seed)
```
ppocr/data/pubtab_dataset.py
```diff
@@ -53,6 +53,9 @@ class PubTabDataSet(Dataset):
             self.shuffle_data_random()
         self.ops = create_operators(dataset_config['transforms'],
                                     global_config)
+        ratio_list = dataset_config.get("ratio_list", [1.0])
+        self.need_reset = True in [x < 1 for x in ratio_list]
+
     def shuffle_data_random(self):
         if self.do_shuffle:
             random.seed(self.seed)
@@ -85,7 +88,11 @@ class PubTabDataSet(Dataset):
             cells = info['html']['cells'].copy()
             structure = info['html']['structure'].copy()
             img_path = os.path.join(self.data_dir, file_name)
             data = {'img_path': img_path, 'cells': cells, 'structure': structure}
+            if not os.path.exists(img_path):
+                raise Exception("{} does not exist!".format(img_path))
             with open(data['img_path'], 'rb') as f:
```
ppocr/data/simple_dataset.py
```diff
@@ -41,7 +41,6 @@ class SimpleDataSet(Dataset):
         ) == data_source_num, "The length of ratio_list should be the same as the file_list."
         self.data_dir = dataset_config['data_dir']
         self.do_shuffle = loader_config['shuffle']
         self.seed = seed
         logger.info("Initialize indexs of datasets:%s" % label_file_list)
         self.data_lines = self.get_image_info_list(label_file_list,
                                                    ratio_list)
@@ -50,6 +49,8 @@ class SimpleDataSet(Dataset):
             self.shuffle_data_random()
         self.ops = create_operators(dataset_config['transforms'],
                                     global_config)
+        self.need_reset = True in [x < 1 for x in ratio_list]
+
     def get_image_info_list(self, file_list, ratio_list):
         if isinstance(file_list, str):
             file_list = [file_list]
```
ppocr/losses/__init__.py
```diff
@@ -16,6 +16,9 @@ import copy
 import paddle
 import paddle.nn as nn
 
+# basic_loss
+from .basic_loss import LossFromOutput
+
 # det loss
 from .det_db_loss import DBLoss
 from .det_east_loss import EASTLoss
@@ -46,12 +49,16 @@ from .combined_loss import CombinedLoss
 # table loss
 from .table_att_loss import TableAttentionLoss
+# vqa token loss
+from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
+
 
 def build_loss(config):
     support_dict = [
         'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
         'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
-        'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss'
+        'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
+        'VQASerTokenLayoutLMLoss', 'LossFromOutput'
     ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')
```
ppocr/losses/basic_loss.py
```diff
@@ -133,3 +133,18 @@ class DistanceLoss(nn.Layer):
 
     def forward(self, x, y):
         return self.loss_func(x, y)
+
+
+class LossFromOutput(nn.Layer):
+    def __init__(self, key='loss', reduction='none'):
+        super().__init__()
+        self.key = key
+        self.reduction = reduction
+
+    def forward(self, predicts, batch):
+        loss = predicts[self.key]
+        if self.reduction == 'mean':
+            loss = paddle.mean(loss)
+        elif self.reduction == 'sum':
+            loss = paddle.sum(loss)
+        return {'loss': loss}
```
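A minimal sketch (not from the commit) of how LossFromOutput is meant to be used: the model's forward already returns a dict containing a loss tensor, and this wrapper merely extracts and optionally reduces it:

```python
import paddle

loss_fn = LossFromOutput(key='loss', reduction='mean')

# Pretend the model's head already computed a per-sample loss.
predicts = {'loss': paddle.to_tensor([0.2, 0.4, 0.6])}
out = loss_fn(predicts, batch=None)   # batch is unused by this loss
print(float(out['loss']))             # 0.4
```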
ppstructure/vqa/losses.py → ppocr/losses/vqa_token_layoutlm_loss.py (renamed, mode 100644 → 100755)
```diff
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,24 +12,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 from paddle import nn
 
 
-class SERLoss(nn.Layer):
+class VQASerTokenLayoutLMLoss(nn.Layer):
     def __init__(self, num_classes):
         super().__init__()
         self.loss_class = nn.CrossEntropyLoss()
         self.num_classes = num_classes
         self.ignore_index = self.loss_class.ignore_index
 
-    def forward(self, labels, outputs, attention_mask):
+    def forward(self, predicts, batch):
+        labels = batch[1]
+        attention_mask = batch[4]
         if attention_mask is not None:
             active_loss = attention_mask.reshape([-1, ]) == 1
-            active_outputs = outputs.reshape(
+            active_outputs = predicts.reshape(
                 [-1, self.num_classes])[active_loss]
             active_labels = labels.reshape([-1, ])[active_loss]
             loss = self.loss_class(active_outputs, active_labels)
         else:
             loss = self.loss_class(
-                outputs.reshape([-1, self.num_classes]),
-                labels.reshape([-1, ]))
-        return loss
+                predicts.reshape([-1, self.num_classes]),
+                labels.reshape([-1, ]))
+        return {'loss': loss}
```
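The reworked signature matches the build_loss convention used elsewhere in ppocr: the loss receives the raw predictions plus the whole dataloader batch, and pulls labels and attention_mask from fixed positions (1 and 4). A toy call with invented shapes, assuming that batch layout:

```python
import paddle

num_classes = 3
loss_fn = VQASerTokenLayoutLMLoss(num_classes)

batch_size, seq_len = 2, 5
predicts = paddle.rand([batch_size, seq_len, num_classes])
labels = paddle.randint(0, num_classes, [batch_size, seq_len])
attention_mask = paddle.ones([batch_size, seq_len], dtype='int64')

# batch layout assumed here: index 1 = labels, index 4 = attention_mask
batch = [None, labels, None, None, attention_mask]
print(loss_fn(predicts, batch)['loss'])
```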
ppocr/metrics/__init__.py
```diff
@@ -28,12 +28,15 @@ from .e2e_metric import E2EMetric
 from .distillation_metric import DistillationMetric
 from .table_metric import TableMetric
 from .kie_metric import KIEMetric
+from .vqa_token_ser_metric import VQASerTokenMetric
+from .vqa_token_re_metric import VQAReTokenMetric
 
 
 def build_metric(config):
     support_dict = [
         "DetMetric", "RecMetric", "ClsMetric", "E2EMetric",
-        "DistillationMetric", "TableMetric", 'KIEMetric'
+        "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
+        'VQAReTokenMetric'
     ]
     config = copy.deepcopy(config)
```
ppocr/metrics/vqa_token_re_metric.py (new file, 0 → 100644)
```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle

# NOTE: __all__ appears to be carried over from kie_metric.py; the class
# defined in this file is VQAReTokenMetric.
__all__ = ['KIEMetric']


class VQAReTokenMetric(object):
    def __init__(self, main_indicator='hmean', **kwargs):
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, preds, batch, **kwargs):
        pred_relations, relations, entities = preds
        self.pred_relations_list.extend(pred_relations)
        self.relations_list.extend(relations)
        self.entities_list.extend(entities)

    def get_metric(self):
        gt_relations = []
        for b in range(len(self.relations_list)):
            rel_sent = []
            for head, tail in zip(self.relations_list[b]["head"],
                                  self.relations_list[b]["tail"]):
                rel = {}
                rel["head_id"] = head
                rel["head"] = (self.entities_list[b]["start"][rel["head_id"]],
                               self.entities_list[b]["end"][rel["head_id"]])
                rel["head_type"] = self.entities_list[b]["label"][rel[
                    "head_id"]]

                rel["tail_id"] = tail
                rel["tail"] = (self.entities_list[b]["start"][rel["tail_id"]],
                               self.entities_list[b]["end"][rel["tail_id"]])
                rel["tail_type"] = self.entities_list[b]["label"][rel[
                    "tail_id"]]

                rel["type"] = 1
                rel_sent.append(rel)
            gt_relations.append(rel_sent)
        re_metrics = self.re_score(
            self.pred_relations_list, gt_relations, mode="boundaries")
        metrics = {
            "precision": re_metrics["ALL"]["p"],
            "recall": re_metrics["ALL"]["r"],
            "hmean": re_metrics["ALL"]["f1"],
        }
        self.reset()
        return metrics

    def reset(self):
        self.pred_relations_list = []
        self.relations_list = []
        self.entities_list = []

    def re_score(self, pred_relations, gt_relations, mode="strict"):
        """Evaluate RE predictions

        Args:
            pred_relations (list): list of list of predicted relations
                (several relations in each sentence)
            gt_relations (list): list of list of ground truth relations
                rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
                        "tail": (start_idx (inclusive), end_idx (exclusive)),
                        "head_type": ent_type,
                        "tail_type": ent_type,
                        "type": rel_type}
            vocab (Vocab): dataset vocabulary
            mode (str): in 'strict' or 'boundaries'
        """
        assert mode in ["strict", "boundaries"]
        relation_types = [v for v in [0, 1] if not v == 0]
        scores = {
            rel: {"tp": 0, "fp": 0, "fn": 0}
            for rel in relation_types + ["ALL"]
        }

        # Count GT relations and Predicted relations
        n_sents = len(gt_relations)
        n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
        n_found = sum([len([rel for rel in sent]) for sent in pred_relations])

        # Count TP, FP and FN per type
        for pred_sent, gt_sent in zip(pred_relations, gt_relations):
            for rel_type in relation_types:
                # strict mode takes argument types into account
                if mode == "strict":
                    pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
                                  rel["tail_type"])
                                 for rel in pred_sent
                                 if rel["type"] == rel_type}
                    gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
                                rel["tail_type"])
                               for rel in gt_sent if rel["type"] == rel_type}
                # boundaries mode only takes argument spans into account
                elif mode == "boundaries":
                    pred_rels = {(rel["head"], rel["tail"])
                                 for rel in pred_sent
                                 if rel["type"] == rel_type}
                    gt_rels = {(rel["head"], rel["tail"])
                               for rel in gt_sent if rel["type"] == rel_type}

                scores[rel_type]["tp"] += len(pred_rels & gt_rels)
                scores[rel_type]["fp"] += len(pred_rels - gt_rels)
                scores[rel_type]["fn"] += len(gt_rels - pred_rels)

        # Compute per entity Precision / Recall / F1
        for rel_type in scores.keys():
            if scores[rel_type]["tp"]:
                scores[rel_type]["p"] = scores[rel_type]["tp"] / (
                    scores[rel_type]["fp"] + scores[rel_type]["tp"])
                scores[rel_type]["r"] = scores[rel_type]["tp"] / (
                    scores[rel_type]["fn"] + scores[rel_type]["tp"])
            else:
                scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0

            if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
                scores[rel_type]["f1"] = (
                    2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
                    (scores[rel_type]["p"] + scores[rel_type]["r"]))
            else:
                scores[rel_type]["f1"] = 0

        # Compute micro F1 Scores
        tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
        fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
        fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])

        if tp:
            precision = tp / (tp + fp)
            recall = tp / (tp + fn)
            f1 = 2 * precision * recall / (precision + recall)
        else:
            precision, recall, f1 = 0, 0, 0

        scores["ALL"]["p"] = precision
        scores["ALL"]["r"] = recall
        scores["ALL"]["f1"] = f1
        scores["ALL"]["tp"] = tp
        scores["ALL"]["fp"] = fp
        scores["ALL"]["fn"] = fn

        # Compute Macro F1 Scores
        scores["ALL"]["Macro_f1"] = np.mean(
            [scores[ent_type]["f1"] for ent_type in relation_types])
        scores["ALL"]["Macro_p"] = np.mean(
            [scores[ent_type]["p"] for ent_type in relation_types])
        scores["ALL"]["Macro_r"] = np.mean(
            [scores[ent_type]["r"] for ent_type in relation_types])

        return scores
```
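A quick sanity check of re_score in "boundaries" mode (toy data, not from the commit): one ground-truth relation, one correct prediction plus one spurious one, giving precision 0.5 and recall 1.0:

```python
m = VQAReTokenMetric()

gt = [[{"head": (0, 4), "tail": (5, 9), "type": 1}]]
pred = [[
    {"head": (0, 4), "tail": (5, 9), "type": 1},    # matches the GT relation
    {"head": (0, 4), "tail": (10, 12), "type": 1},  # false positive
]]

scores = m.re_score(pred, gt, mode="boundaries")
print(scores["ALL"])  # p = 0.5, r = 1.0, f1 ~ 0.667
```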
ppocr/metrics/vqa_token_ser_metric.py (new file, 0 → 100644)
```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle

# NOTE: __all__ appears to be carried over from kie_metric.py; the class
# defined in this file is VQASerTokenMetric.
__all__ = ['KIEMetric']


class VQASerTokenMetric(object):
    def __init__(self, main_indicator='hmean', **kwargs):
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, preds, batch, **kwargs):
        preds, labels = preds
        self.pred_list.extend(preds)
        self.gt_list.extend(labels)

    def get_metric(self):
        from seqeval.metrics import f1_score, precision_score, recall_score

        # spelled 'metircs' in the original commit; renamed here for clarity
        metrics = {
            "precision": precision_score(self.gt_list, self.pred_list),
            "recall": recall_score(self.gt_list, self.pred_list),
            "hmean": f1_score(self.gt_list, self.pred_list),
        }
        self.reset()
        return metrics

    def reset(self):
        self.pred_list = []
        self.gt_list = []
```
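The metric delegates entirely to seqeval, which expects one IOB tag sequence per sample; a toy run (requires the seqeval package, tag names invented):

```python
m = VQASerTokenMetric()

preds = [['B-QUESTION', 'I-QUESTION', 'O', 'B-ANSWER']]
labels = [['B-QUESTION', 'I-QUESTION', 'O', 'B-ANSWER']]

m((preds, labels), batch=None)
print(m.get_metric())  # {'precision': 1.0, 'recall': 1.0, 'hmean': 1.0}
```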
ppocr/modeling/architectures/base_model.py
```diff
@@ -63,6 +63,10 @@ class BaseModel(nn.Layer):
             in_channels = self.neck.out_channels
 
         # # build head, head is need for det, rec and cls
+        if 'Head' not in config or config['Head'] is None:
+            self.use_head = False
+        else:
+            self.use_head = True
             config["Head"]['in_channels'] = in_channels
             self.head = build_head(config["Head"])
@@ -77,6 +81,7 @@ class BaseModel(nn.Layer):
         if self.use_neck:
             x = self.neck(x)
         y["neck_out"] = x
+        if self.use_head:
             x = self.head(x, targets=data)
         if isinstance(x, dict):
             y.update(x)
```
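This is what lets the VQA models plug into BaseModel: the paddlenlp LayoutLM/LayoutXLM backbones already contain their task heads, so their configs simply omit Head and use_head stays False. Schematically (a hypothetical config fragment, not from the commit):

```python
# Hypothetical architecture config: the VQA backbone is the whole model,
# so there is no Neck and Head is None; BaseModel then skips build_head
# and returns the backbone output directly.
arch_config = {
    "model_type": "vqa",
    "Backbone": {"name": "LayoutXLMForSer", "num_classes": 7},
    "Head": None,
}
```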
ppocr/modeling/backbones/__init__.py
```diff
@@ -43,6 +43,9 @@ def build_backbone(config, model_type):
         from .table_resnet_vd import ResNet
         from .table_mobilenet_v3 import MobileNetV3
         support_dict = ["ResNet", "MobileNetV3"]
+    elif model_type == 'vqa':
+        from .vqa_layoutlm import LayoutLMForSer, LayoutXLMForSer, LayoutXLMForRe
+        support_dict = ["LayoutLMForSer", "LayoutXLMForSer", 'LayoutXLMForRe']
     else:
         raise NotImplementedError
```
ppocr/modeling/backbones/vqa_layoutlm.py (new file, 0 → 100644)
```python
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from paddle import nn

from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification

__all__ = ["LayoutXLMForSer", 'LayoutLMForSer']

pretrained_model_dict = {
    LayoutXLMModel: 'layoutxlm-base-uncased',
    LayoutLMModel: 'layoutlm-base-uncased'
}


class NLPBaseModel(nn.Layer):
    def __init__(self,
                 base_model_class,
                 model_class,
                 type='ser',
                 pretrained=True,
                 checkpoints=None,
                 **kwargs):
        super(NLPBaseModel, self).__init__()
        if checkpoints is not None:
            # resume from a fine-tuned checkpoint directory
            self.model = model_class.from_pretrained(checkpoints)
        else:
            pretrained_model_name = pretrained_model_dict[base_model_class]
            if pretrained:
                base_model = base_model_class.from_pretrained(
                    pretrained_model_name)
            else:
                base_model = base_model_class(
                    **base_model_class.pretrained_init_configuration[
                        pretrained_model_name])
            if type == 'ser':
                self.model = model_class(
                    base_model, num_classes=kwargs['num_classes'], dropout=None)
            else:
                self.model = model_class(base_model, dropout=None)
        self.out_channels = 1


class LayoutXLMForSer(NLPBaseModel):
    def __init__(self, num_classes, pretrained=True, checkpoints=None,
                 **kwargs):
        super(LayoutXLMForSer, self).__init__(
            LayoutXLMModel,
            LayoutXLMForTokenClassification,
            'ser',
            pretrained,
            checkpoints,
            num_classes=num_classes)

    def forward(self, x):
        x = self.model(
            input_ids=x[0],
            bbox=x[2],
            image=x[3],
            attention_mask=x[4],
            token_type_ids=x[5],
            position_ids=None,
            head_mask=None,
            labels=None)
        return x[0]


class LayoutLMForSer(NLPBaseModel):
    def __init__(self, num_classes, pretrained=True, checkpoints=None,
                 **kwargs):
        super(LayoutLMForSer, self).__init__(
            LayoutLMModel,
            LayoutLMForTokenClassification,
            'ser',
            pretrained,
            checkpoints,
            num_classes=num_classes)

    def forward(self, x):
        x = self.model(
            input_ids=x[0],
            bbox=x[2],
            attention_mask=x[4],
            token_type_ids=x[5],
            position_ids=None,
            output_hidden_states=False)
        return x


class LayoutXLMForRe(NLPBaseModel):
    def __init__(self, pretrained=True, checkpoints=None, **kwargs):
        super(LayoutXLMForRe, self).__init__(LayoutXLMModel,
                                             LayoutXLMForRelationExtraction,
                                             're', pretrained, checkpoints)

    def forward(self, x):
        x = self.model(
            input_ids=x[0],
            bbox=x[1],
            labels=None,
            image=x[2],
            attention_mask=x[3],
            token_type_ids=x[4],
            position_ids=None,
            head_mask=None,
            entities=x[5],
            relations=x[6])
        return x
```
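A usage sketch (assumptions flagged in the comments): constructing LayoutXLMForSer with pretrained=True downloads the layoutxlm-base-uncased weights through paddlenlp on first use, and forward expects the batch fields at the positions wired up above:

```python
# Downloads the 'layoutxlm-base-uncased' weights via paddlenlp on first run.
model = LayoutXLMForSer(num_classes=7, pretrained=True)

# forward(x) indexes the batch positionally:
#   x[0] input_ids, x[2] bbox, x[3] image, x[4] attention_mask,
#   x[5] token_type_ids
# and returns the token-classification logits (element 0 of the
# paddlenlp output tuple).
```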
ppocr/optimizer/__init__.py
```diff
@@ -42,7 +42,9 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
     # step2 build regularization
     if 'regularizer' in config and config['regularizer'] is not None:
         reg_config = config.pop('regularizer')
-        reg_name = reg_config.pop('name') + 'Decay'
+        reg_name = reg_config.pop('name')
+        if not hasattr(regularizer, reg_name):
+            reg_name += 'Decay'
         reg = getattr(regularizer, reg_name)(**reg_config)()
     else:
         reg = None
```
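The hasattr fallback keeps old configs working: a regularizer may now be named either 'L2' (legacy form, 'Decay' gets appended) or 'L2Decay' directly. The resolution logic in isolation:

```python
from ppocr.optimizer import regularizer  # run from the repository root

for name in ('L2', 'L2Decay'):
    reg_name = name
    if not hasattr(regularizer, reg_name):
        reg_name += 'Decay'
    print(name, '->', reg_name)  # both resolve to 'L2Decay'
```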
ppocr/optimizer/optimizer.py
```diff
@@ -158,3 +158,38 @@ class Adadelta(object):
             name=self.name,
             parameters=parameters)
         return opt
+
+
+class AdamW(object):
+    def __init__(self,
+                 learning_rate=0.001,
+                 beta1=0.9,
+                 beta2=0.999,
+                 epsilon=1e-08,
+                 weight_decay=0.01,
+                 grad_clip=None,
+                 name=None,
+                 lazy_mode=False,
+                 **kwargs):
+        self.learning_rate = learning_rate
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.learning_rate = learning_rate  # duplicated assignment in the original commit
+        self.weight_decay = 0.01 if weight_decay is None else weight_decay
+        self.grad_clip = grad_clip
+        self.name = name
+        self.lazy_mode = lazy_mode
+
+    def __call__(self, parameters):
+        opt = optim.AdamW(
+            learning_rate=self.learning_rate,
+            beta1=self.beta1,
+            beta2=self.beta2,
+            epsilon=self.epsilon,
+            weight_decay=self.weight_decay,
+            grad_clip=self.grad_clip,
+            name=self.name,
+            lazy_mode=self.lazy_mode,
+            parameters=parameters)
+        return opt
```
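The wrapper follows the same pattern as the other optimizer classes in this file: construct it from config values, then call it with the model parameters to obtain a paddle.optimizer.AdamW instance. A minimal sketch:

```python
import paddle

linear = paddle.nn.Linear(4, 2)           # stand-in model
builder = AdamW(learning_rate=5e-5, weight_decay=0.01)
optimizer = builder(parameters=linear.parameters())  # a paddle.optimizer.AdamW
```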
ppocr/optimizer/regularizer.py
```diff
@@ -29,24 +29,23 @@ class L1Decay(object):
     def __init__(self, factor=0.0):
         super(L1Decay, self).__init__()
-        self.regularization_coeff = factor
+        self.coeff = factor
 
     def __call__(self):
-        reg = paddle.regularizer.L1Decay(self.regularization_coeff)
+        reg = paddle.regularizer.L1Decay(self.coeff)
         return reg
 
 
 class L2Decay(object):
     """
-    L2 Weight Decay Regularization, which encourages the weights to be sparse.
+    L2 Weight Decay Regularization, which helps to prevent the model over-fitting.
     Args:
         factor(float): regularization coeff. Default:0.0.
     """
 
     def __init__(self, factor=0.0):
         super(L2Decay, self).__init__()
-        self.regularization_coeff = factor
+        self.coeff = float(factor)
 
     def __call__(self):
-        reg = paddle.regularizer.L2Decay(self.regularization_coeff)
-        return reg
+        return self.coeff
\ No newline at end of file
```
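Note the behavioral change hidden in the last hunk: L2Decay()() now returns the bare float coefficient instead of a paddle.regularizer.L2Decay object. Paddle optimizers accept a float for weight_decay, so build_optimizer can pass the value straight through (the AdamW wrapper above consumes it the same way):

```python
reg = L2Decay(factor=1e-4)()   # now just 1e-4, a plain float
# paddle.optimizer.Adam(..., weight_decay=reg) accepts a float coefficient,
# so callers need no change.
```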
ppocr/postprocess/__init__.py
```diff
@@ -28,6 +28,8 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
     TableLabelDecode, NRTRLabelDecode, SARLabelDecode, SEEDLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
+from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
+from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
 
 
 def build_post_process(config, global_config=None):
@@ -36,7 +38,8 @@ def build_post_process(config, global_config=None):
         'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
         'DistillationCTCLabelDecode', 'TableLabelDecode',
         'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
-        'SEEDLabelDecode'
+        'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
+        'VQAReTokenLayoutLMPostProcess'
     ]
 
     if config['name'] == 'PSEPostProcess':
```
ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py (new file, 0 → 100644)
```python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle


class VQAReTokenLayoutLMPostProcess(object):
    """ Convert between text-label and text-index """

    def __init__(self, **kwargs):
        super(VQAReTokenLayoutLMPostProcess, self).__init__()

    def __call__(self, preds, label=None, *args, **kwargs):
        if label is not None:
            return self._metric(preds, label)
        else:
            return self._infer(preds, *args, **kwargs)

    def _metric(self, preds, label):
        return preds['pred_relations'], label[6], label[5]

    def _infer(self, preds, *args, **kwargs):
        ser_results = kwargs['ser_results']
        entity_idx_dict_batch = kwargs['entity_idx_dict_batch']
        pred_relations = preds['pred_relations']

        # merge relations and ocr info
        results = []
        for pred_relation, ser_result, entity_idx_dict in zip(
                pred_relations, ser_results, entity_idx_dict_batch):
            result = []
            used_tail_id = []
            for relation in pred_relation:
                if relation['tail_id'] in used_tail_id:
                    continue
                used_tail_id.append(relation['tail_id'])
                ocr_info_head = ser_result[entity_idx_dict[relation[
                    'head_id']]]
                ocr_info_tail = ser_result[entity_idx_dict[relation[
                    'tail_id']]]
                result.append((ocr_info_head, ocr_info_tail))
            results.append(result)
        return results
```
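In inference mode the postprocessor pairs SER results through the predicted relations, skipping any tail already consumed; a toy run with invented OCR dicts:

```python
post = VQAReTokenLayoutLMPostProcess()

preds = {'pred_relations': [[{'head_id': 0, 'tail_id': 1}]]}
ser_results = [[{'text': 'Name:'}, {'text': 'Alice'}]]
entity_idx_dict_batch = [{0: 0, 1: 1}]

print(post(preds,
           ser_results=ser_results,
           entity_idx_dict_batch=entity_idx_dict_batch))
# [[({'text': 'Name:'}, {'text': 'Alice'})]]
```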