Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
a323fce6
Commit
a323fce6
authored
Jan 05, 2022
by
WenmuZhou
Browse files
vqa code integrated into ppocr training system
parent
1ded2ac4
Changes
54
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1201 additions
and
26 deletions
+1201
-26
configs/vqa/re/layoutxlm.yml
configs/vqa/re/layoutxlm.yml
+122
-0
configs/vqa/ser/layoutlm.yml
configs/vqa/ser/layoutlm.yml
+120
-0
configs/vqa/ser/layoutxlm.yml
configs/vqa/ser/layoutxlm.yml
+121
-0
ppocr/data/__init__.py
ppocr/data/__init__.py
+7
-1
ppocr/data/collate_fn.py
ppocr/data/collate_fn.py
+22
-4
ppocr/data/imaug/__init__.py
ppocr/data/imaug/__init__.py
+2
-0
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+207
-1
ppocr/data/imaug/operators.py
ppocr/data/imaug/operators.py
+10
-8
ppocr/data/imaug/vqa/__init__.py
ppocr/data/imaug/vqa/__init__.py
+19
-0
ppocr/data/imaug/vqa/token/__init__.py
ppocr/data/imaug/vqa/token/__init__.py
+17
-0
ppocr/data/imaug/vqa/token/vqa_token_chunk.py
ppocr/data/imaug/vqa/token/vqa_token_chunk.py
+117
-0
ppocr/data/imaug/vqa/token/vqa_token_pad.py
ppocr/data/imaug/vqa/token/vqa_token_pad.py
+101
-0
ppocr/data/imaug/vqa/token/vqa_token_relation.py
ppocr/data/imaug/vqa/token/vqa_token_relation.py
+67
-0
ppocr/data/simple_dataset.py
ppocr/data/simple_dataset.py
+4
-3
ppocr/losses/__init__.py
ppocr/losses/__init__.py
+8
-1
ppocr/losses/basic_loss.py
ppocr/losses/basic_loss.py
+15
-0
ppocr/losses/vqa_token_layoutlm_loss.py
ppocr/losses/vqa_token_layoutlm_loss.py
+14
-7
ppocr/metrics/__init__.py
ppocr/metrics/__init__.py
+4
-1
ppocr/metrics/vqa_token_re_metric.py
ppocr/metrics/vqa_token_re_metric.py
+177
-0
ppocr/metrics/vqa_token_ser_metric.py
ppocr/metrics/vqa_token_ser_metric.py
+47
-0
No files found.
configs/vqa/re/layoutxlm.yml
0 → 100644
View file @
a323fce6
Global
:
use_gpu
:
True
epoch_num
:
200
log_smooth_window
:
10
print_batch_step
:
10
save_model_dir
:
./output/re_layoutxlm/
save_epoch_step
:
2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step
:
[
0
,
38
]
cal_metric_during_train
:
False
pretrained_model
:
&pretrained_model
layoutxlm-base-uncased
save_inference_dir
:
use_visualdl
:
False
infer_img
:
ppstructure/vqa/images/input/zh_val_21.jpg
save_res_path
:
./output/re/
Architecture
:
model_type
:
vqa
algorithm
:
&algorithm
"
LayoutXLM"
Transform
:
Backbone
:
name
:
LayoutXLMForRe
pretrained_model
:
*pretrained_model
checkpoints
:
Loss
:
name
:
LossFromOutput
key
:
loss
reduction
:
mean
Optimizer
:
name
:
AdamW
beta1
:
0.9
beta2
:
0.999
clip_norm
:
10
lr
:
learning_rate
:
0.00005
regularizer
:
name
:
Const
factor
:
0.00000
PostProcess
:
name
:
VQAReTokenLayoutLMPostProcess
Metric
:
name
:
VQAReTokenMetric
main_indicator
:
hmean
Train
:
dataset
:
name
:
SimpleDataSet
data_dir
:
train_data/XFUND/zh_train/image
label_file_list
:
-
train_data/XFUND/zh_train/xfun_normalize_train.json
ratio_list
:
[
1.0
]
transforms
:
-
DecodeImage
:
# load image
img_mode
:
RGB
channel_first
:
False
-
VQATokenLabelEncode
:
# Class handling label
contains_re
:
True
algorithm
:
*algorithm
class_path
:
&class_path
ppstructure/vqa/labels/labels_ser.txt
-
VQATokenPad
:
max_seq_len
:
&max_seq_len
512
return_attention_mask
:
True
-
VQAReTokenRelation
:
-
VQAReTokenChunk
:
max_seq_len
:
*max_seq_len
-
Resize
:
size
:
[
224
,
224
]
-
NormalizeImage
:
scale
:
1
mean
:
[
123.675
,
116.28
,
103.53
]
std
:
[
58.395
,
57.12
,
57.375
]
order
:
'
hwc'
-
ToCHWImage
:
-
KeepKeys
:
keep_keys
:
[
'
input_ids'
,
'
bbox'
,
'
image'
,
'
attention_mask'
,
'
token_type_ids'
,
'
entities'
,
'
relations'
]
# dataloader will return list in this order
loader
:
shuffle
:
True
drop_last
:
False
batch_size_per_card
:
8
num_workers
:
4
collate_fn
:
ListCollator
Eval
:
dataset
:
name
:
SimpleDataSet
data_dir
:
train_data/XFUND/zh_val/image
label_file_list
:
-
train_data/XFUND/zh_val/xfun_normalize_val.json
transforms
:
-
DecodeImage
:
# load image
img_mode
:
RGB
channel_first
:
False
-
VQATokenLabelEncode
:
# Class handling label
contains_re
:
True
algorithm
:
*algorithm
class_path
:
*class_path
-
VQATokenPad
:
max_seq_len
:
*max_seq_len
return_attention_mask
:
True
-
VQAReTokenRelation
:
-
VQAReTokenChunk
:
max_seq_len
:
*max_seq_len
-
Resize
:
size
:
[
224
,
224
]
-
NormalizeImage
:
scale
:
1
mean
:
[
123.675
,
116.28
,
103.53
]
std
:
[
58.395
,
57.12
,
57.375
]
order
:
'
hwc'
-
ToCHWImage
:
-
KeepKeys
:
keep_keys
:
[
'
input_ids'
,
'
bbox'
,
'
image'
,
'
attention_mask'
,
'
token_type_ids'
,
'
entities'
,
'
relations'
]
# dataloader will return list in this order
loader
:
shuffle
:
False
drop_last
:
False
batch_size_per_card
:
8
num_workers
:
4
collate_fn
:
ListCollator
configs/vqa/ser/layoutlm.yml
0 → 100644
View file @
a323fce6
Global
:
use_gpu
:
True
epoch_num
:
&epoch_num
200
log_smooth_window
:
10
print_batch_step
:
10
save_model_dir
:
./output/ser_layoutlm/
save_epoch_step
:
2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step
:
[
0
,
19
]
cal_metric_during_train
:
False
pretrained_model
:
&pretrained_model
layoutlm-base-uncased
save_inference_dir
:
use_visualdl
:
False
infer_img
:
ppstructure/vqa/images/input/zh_val_0.jpg
save_res_path
:
./output/ser/predicts_layoutlm.txt
Architecture
:
model_type
:
vqa
algorithm
:
&algorithm
"
LayoutLM"
Transform
:
Backbone
:
name
:
LayoutLMForSer
pretrained_model
:
*pretrained_model
checkpoints
:
num_classes
:
&num_classes
7
Loss
:
name
:
VQASerTokenLayoutLMLoss
num_classes
:
*num_classes
Optimizer
:
name
:
AdamW
beta1
:
0.9
beta2
:
0.999
lr
:
name
:
Linear
learning_rate
:
0.00005
epochs
:
*epoch_num
warmup_epoch
:
2
regularizer
:
name
:
Const
factor
:
0.00000
PostProcess
:
name
:
VQASerTokenLayoutLMPostProcess
class_path
:
&class_path
ppstructure/vqa/labels/labels_ser.txt
Metric
:
name
:
VQASerTokenMetric
main_indicator
:
hmean
Train
:
dataset
:
name
:
SimpleDataSet
data_dir
:
train_data/XFUND/zh_train/image
label_file_list
:
-
train_data/XFUND/zh_train/xfun_normalize_train.json
transforms
:
-
DecodeImage
:
# load image
img_mode
:
RGB
channel_first
:
False
-
VQATokenLabelEncode
:
# Class handling label
contains_re
:
False
algorithm
:
*algorithm
class_path
:
*class_path
-
VQATokenPad
:
max_seq_len
:
&max_seq_len
512
return_attention_mask
:
True
-
VQASerTokenChunk
:
max_seq_len
:
*max_seq_len
-
Resize
:
size
:
[
224
,
224
]
-
NormalizeImage
:
scale
:
1
mean
:
[
123.675
,
116.28
,
103.53
]
std
:
[
58.395
,
57.12
,
57.375
]
order
:
'
hwc'
-
ToCHWImage
:
-
KeepKeys
:
keep_keys
:
[
'
input_ids'
,
'
labels'
,
'
bbox'
,
'
image'
,
'
attention_mask'
,
'
token_type_ids'
]
# dataloader will return list in this order
loader
:
shuffle
:
True
drop_last
:
False
batch_size_per_card
:
8
num_workers
:
4
Eval
:
dataset
:
name
:
SimpleDataSet
data_dir
:
train_data/XFUND/zh_val/image
label_file_list
:
-
train_data/XFUND/zh_val/xfun_normalize_val.json
transforms
:
-
DecodeImage
:
# load image
img_mode
:
RGB
channel_first
:
False
-
VQATokenLabelEncode
:
# Class handling label
contains_re
:
False
algorithm
:
*algorithm
class_path
:
*class_path
-
VQATokenPad
:
max_seq_len
:
*max_seq_len
return_attention_mask
:
True
-
VQASerTokenChunk
:
max_seq_len
:
*max_seq_len
-
Resize
:
size
:
[
224
,
224
]
-
NormalizeImage
:
scale
:
1
mean
:
[
123.675
,
116.28
,
103.53
]
std
:
[
58.395
,
57.12
,
57.375
]
order
:
'
hwc'
-
ToCHWImage
:
-
KeepKeys
:
keep_keys
:
[
'
input_ids'
,
'
labels'
,
'
bbox'
,
'
image'
,
'
attention_mask'
,
'
token_type_ids'
]
# dataloader will return list in this order
loader
:
shuffle
:
False
drop_last
:
False
batch_size_per_card
:
8
num_workers
:
4
configs/vqa/ser/layoutxlm.yml
0 → 100644
View file @
a323fce6
Global
:
use_gpu
:
True
epoch_num
:
&epoch_num
200
log_smooth_window
:
10
print_batch_step
:
10
save_model_dir
:
./output/ser_layoutxlm/
save_epoch_step
:
2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step
:
[
0
,
19
]
cal_metric_during_train
:
False
pretrained_model
:
&pretrained_model
layoutxlm-base-uncased
save_inference_dir
:
use_visualdl
:
False
infer_img
:
ppstructure/vqa/images/input/zh_val_42.jpg
save_res_path
:
./output/ser
Architecture
:
model_type
:
vqa
algorithm
:
&algorithm
"
LayoutXLM"
Transform
:
Backbone
:
name
:
LayoutXLMForSer
pretrained_model
:
*pretrained_model
checkpoints
:
num_classes
:
&num_classes
7
Loss
:
name
:
VQASerTokenLayoutLMLoss
num_classes
:
*num_classes
Optimizer
:
name
:
AdamW
beta1
:
0.9
beta2
:
0.999
lr
:
name
:
Linear
learning_rate
:
0.00005
epochs
:
*epoch_num
warmup_epoch
:
2
regularizer
:
name
:
Const
factor
:
0.00000
PostProcess
:
name
:
VQASerTokenLayoutLMPostProcess
class_path
:
&class_path
ppstructure/vqa/labels/labels_ser.txt
Metric
:
name
:
VQASerTokenMetric
main_indicator
:
hmean
Train
:
dataset
:
name
:
SimpleDataSet
data_dir
:
train_data/XFUND/zh_train/image
label_file_list
:
-
train_data/XFUND/zh_train/xfun_normalize_train.json
ratio_list
:
[
1.0
]
transforms
:
-
DecodeImage
:
# load image
img_mode
:
RGB
channel_first
:
False
-
VQATokenLabelEncode
:
# Class handling label
contains_re
:
False
algorithm
:
*algorithm
class_path
:
*class_path
-
VQATokenPad
:
max_seq_len
:
&max_seq_len
512
return_attention_mask
:
True
-
VQASerTokenChunk
:
max_seq_len
:
*max_seq_len
-
Resize
:
size
:
[
224
,
224
]
-
NormalizeImage
:
scale
:
1
mean
:
[
123.675
,
116.28
,
103.53
]
std
:
[
58.395
,
57.12
,
57.375
]
order
:
'
hwc'
-
ToCHWImage
:
-
KeepKeys
:
keep_keys
:
[
'
input_ids'
,
'
labels'
,
'
bbox'
,
'
image'
,
'
attention_mask'
,
'
token_type_ids'
]
# dataloader will return list in this order
loader
:
shuffle
:
True
drop_last
:
False
batch_size_per_card
:
8
num_workers
:
4
Eval
:
dataset
:
name
:
SimpleDataSet
data_dir
:
train_data/XFUND/zh_val/image
label_file_list
:
-
train_data/XFUND/zh_val/xfun_normalize_val.json
transforms
:
-
DecodeImage
:
# load image
img_mode
:
RGB
channel_first
:
False
-
VQATokenLabelEncode
:
# Class handling label
contains_re
:
False
algorithm
:
*algorithm
class_path
:
*class_path
-
VQATokenPad
:
max_seq_len
:
*max_seq_len
return_attention_mask
:
True
-
VQASerTokenChunk
:
max_seq_len
:
*max_seq_len
-
Resize
:
size
:
[
224
,
224
]
-
NormalizeImage
:
scale
:
1
mean
:
[
123.675
,
116.28
,
103.53
]
std
:
[
58.395
,
57.12
,
57.375
]
order
:
'
hwc'
-
ToCHWImage
:
-
KeepKeys
:
keep_keys
:
[
'
input_ids'
,
'
labels'
,
'
bbox'
,
'
image'
,
'
attention_mask'
,
'
token_type_ids'
]
# dataloader will return list in this order
loader
:
shuffle
:
False
drop_last
:
False
batch_size_per_card
:
8
num_workers
:
4
ppocr/data/__init__.py
View file @
a323fce6
...
...
@@ -86,13 +86,19 @@ def build_dataloader(config, mode, device, logger, seed=None):
shuffle
=
shuffle
,
drop_last
=
drop_last
)
if
'collate_fn'
in
loader_config
:
from
.
import
collate_fn
collate_fn
=
getattr
(
collate_fn
,
loader_config
[
'collate_fn'
])()
else
:
collate_fn
=
None
data_loader
=
DataLoader
(
dataset
=
dataset
,
batch_sampler
=
batch_sampler
,
places
=
device
,
num_workers
=
num_workers
,
return_list
=
True
,
use_shared_memory
=
use_shared_memory
)
use_shared_memory
=
use_shared_memory
,
collate_fn
=
collate_fn
)
# support exit using ctrl+c
signal
.
signal
(
signal
.
SIGINT
,
term_mp
)
...
...
pp
structure/vqa
/data
_
collat
or
.py
→
pp
ocr
/data
/
collat
e_fn
.py
View file @
a323fce6
...
...
@@ -15,20 +15,19 @@
import
paddle
import
numbers
import
numpy
as
np
from
collections
import
defaultdict
class
D
ata
Collator
:
class
D
ict
Collator
(
object
)
:
"""
data batch
"""
def
__call__
(
self
,
batch
):
data_dict
=
{}
data_dict
=
defaultdict
(
list
)
to_tensor_keys
=
[]
for
sample
in
batch
:
for
k
,
v
in
sample
.
items
():
if
k
not
in
data_dict
:
data_dict
[
k
]
=
[]
if
isinstance
(
v
,
(
np
.
ndarray
,
paddle
.
Tensor
,
numbers
.
Number
)):
if
k
not
in
to_tensor_keys
:
to_tensor_keys
.
append
(
k
)
...
...
@@ -36,3 +35,22 @@ class DataCollator:
for
k
in
to_tensor_keys
:
data_dict
[
k
]
=
paddle
.
to_tensor
(
data_dict
[
k
])
return
data_dict
class
ListCollator
(
object
):
"""
data batch
"""
def
__call__
(
self
,
batch
):
data_dict
=
defaultdict
(
list
)
to_tensor_idxs
=
[]
for
sample
in
batch
:
for
idx
,
v
in
enumerate
(
sample
):
if
isinstance
(
v
,
(
np
.
ndarray
,
paddle
.
Tensor
,
numbers
.
Number
)):
if
idx
not
in
to_tensor_idxs
:
to_tensor_idxs
.
append
(
idx
)
data_dict
[
idx
].
append
(
v
)
for
idx
in
to_tensor_idxs
:
data_dict
[
idx
]
=
paddle
.
to_tensor
(
data_dict
[
idx
])
return
list
(
data_dict
.
values
())
ppocr/data/imaug/__init__.py
View file @
a323fce6
...
...
@@ -34,6 +34,8 @@ from .sast_process import *
from
.pg_process
import
*
from
.gen_table_mask
import
*
from
.vqa
import
*
def
transform
(
data
,
ops
=
None
):
""" transform """
...
...
ppocr/data/imaug/label_ops.py
View file @
a323fce6
...
...
@@ -17,6 +17,7 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
unicode_literals
import
copy
import
numpy
as
np
import
string
from
shapely.geometry
import
LineString
,
Point
,
Polygon
...
...
@@ -782,3 +783,208 @@ class SARLabelEncode(BaseRecLabelEncode):
def
get_ignored_tokens
(
self
):
return
[
self
.
padding_idx
]
class
VQATokenLabelEncode
(
object
):
"""
基于NLP的标签编码
"""
def
__init__
(
self
,
class_path
,
contains_re
=
False
,
add_special_ids
=
False
,
algorithm
=
'LayoutXLM'
,
infer_mode
=
False
,
ocr_engine
=
None
,
**
kwargs
):
super
(
VQATokenLabelEncode
,
self
).
__init__
()
from
paddlenlp.transformers
import
LayoutXLMTokenizer
,
LayoutLMTokenizer
from
ppocr.utils.utility
import
load_vqa_bio_label_maps
tokenizer_dict
=
{
'LayoutXLM'
:
{
'class'
:
LayoutXLMTokenizer
,
'pretrained_model'
:
'layoutxlm-base-uncased'
},
'LayoutLM'
:
{
'class'
:
LayoutLMTokenizer
,
'pretrained_model'
:
'layoutlm-base-uncased'
}
}
self
.
contains_re
=
contains_re
tokenizer_config
=
tokenizer_dict
[
algorithm
]
self
.
tokenizer
=
tokenizer_config
[
'class'
].
from_pretrained
(
tokenizer_config
[
'pretrained_model'
])
self
.
label2id_map
,
id2label_map
=
load_vqa_bio_label_maps
(
class_path
)
self
.
add_special_ids
=
add_special_ids
self
.
infer_mode
=
infer_mode
self
.
ocr_engine
=
ocr_engine
def
__call__
(
self
,
data
):
if
self
.
infer_mode
==
False
:
return
self
.
_train
(
data
)
else
:
return
self
.
_infer
(
data
)
def
_train
(
self
,
data
):
info
=
data
[
'label'
]
# read text info
info_dict
=
json
.
loads
(
info
)
height
=
info_dict
[
"height"
]
width
=
info_dict
[
"width"
]
words_list
=
[]
bbox_list
=
[]
input_ids_list
=
[]
token_type_ids_list
=
[]
gt_label_list
=
[]
if
self
.
contains_re
:
# for re
entities
=
[]
relations
=
[]
id2label
=
{}
entity_id_to_index_map
=
{}
empty_entity
=
set
()
for
info
in
info_dict
[
"ocr_info"
]:
if
self
.
contains_re
:
# for re
if
len
(
info
[
"text"
])
==
0
:
empty_entity
.
add
(
info
[
"id"
])
continue
id2label
[
info
[
"id"
]]
=
info
[
"label"
]
relations
.
extend
([
tuple
(
sorted
(
l
))
for
l
in
info
[
"linking"
]])
# x1, y1, x2, y2
bbox
=
info
[
"bbox"
]
label
=
info
[
"label"
]
bbox
[
0
]
=
int
(
bbox
[
0
]
*
1000.0
/
width
)
bbox
[
2
]
=
int
(
bbox
[
2
]
*
1000.0
/
width
)
bbox
[
1
]
=
int
(
bbox
[
1
]
*
1000.0
/
height
)
bbox
[
3
]
=
int
(
bbox
[
3
]
*
1000.0
/
height
)
text
=
info
[
"text"
]
encode_res
=
self
.
tokenizer
.
encode
(
text
,
pad_to_max_seq_len
=
False
,
return_attention_mask
=
True
)
gt_label
=
[]
if
not
self
.
add_special_ids
:
# TODO: use tok.all_special_ids to remove
encode_res
[
"input_ids"
]
=
encode_res
[
"input_ids"
][
1
:
-
1
]
encode_res
[
"token_type_ids"
]
=
encode_res
[
"token_type_ids"
][
1
:
-
1
]
encode_res
[
"attention_mask"
]
=
encode_res
[
"attention_mask"
][
1
:
-
1
]
if
label
.
lower
()
==
"other"
:
gt_label
.
extend
([
0
]
*
len
(
encode_res
[
"input_ids"
]))
else
:
gt_label
.
append
(
self
.
label2id_map
[(
"b-"
+
label
).
upper
()])
gt_label
.
extend
([
self
.
label2id_map
[(
"i-"
+
label
).
upper
()]]
*
(
len
(
encode_res
[
"input_ids"
])
-
1
))
if
self
.
contains_re
:
if
gt_label
[
0
]
!=
self
.
label2id_map
[
"O"
]:
entity_id_to_index_map
[
info
[
"id"
]]
=
len
(
entities
)
entities
.
append
({
"start"
:
len
(
input_ids_list
),
"end"
:
len
(
input_ids_list
)
+
len
(
encode_res
[
"input_ids"
]),
"label"
:
label
.
upper
(),
})
input_ids_list
.
extend
(
encode_res
[
"input_ids"
])
token_type_ids_list
.
extend
(
encode_res
[
"token_type_ids"
])
bbox_list
.
extend
([
bbox
]
*
len
(
encode_res
[
"input_ids"
]))
gt_label_list
.
extend
(
gt_label
)
words_list
.
append
(
text
)
encoded_inputs
=
{
"input_ids"
:
input_ids_list
,
"labels"
:
gt_label_list
,
"token_type_ids"
:
token_type_ids_list
,
"bbox"
:
bbox_list
,
"attention_mask"
:
[
1
]
*
len
(
input_ids_list
),
}
data
.
update
(
encoded_inputs
)
data
[
'tokenizer_params'
]
=
dict
(
padding_side
=
self
.
tokenizer
.
padding_side
,
pad_token_type_id
=
self
.
tokenizer
.
pad_token_type_id
,
pad_token_id
=
self
.
tokenizer
.
pad_token_id
)
if
self
.
contains_re
:
data
[
'entities'
]
=
entities
data
[
'relations'
]
=
relations
data
[
'id2label'
]
=
id2label
data
[
'empty_entity'
]
=
empty_entity
data
[
'entity_id_to_index_map'
]
=
entity_id_to_index_map
return
data
def
_infer
(
self
,
data
):
def
trans_poly_to_bbox
(
poly
):
x1
=
np
.
min
([
p
[
0
]
for
p
in
poly
])
x2
=
np
.
max
([
p
[
0
]
for
p
in
poly
])
y1
=
np
.
min
([
p
[
1
]
for
p
in
poly
])
y2
=
np
.
max
([
p
[
1
]
for
p
in
poly
])
return
[
x1
,
y1
,
x2
,
y2
]
height
,
width
,
_
=
data
[
'image'
].
shape
ocr_result
=
self
.
ocr_engine
.
ocr
(
data
[
'image'
],
cls
=
False
)
ocr_info
=
[]
for
res
in
ocr_result
:
ocr_info
.
append
({
"text"
:
res
[
1
][
0
],
"bbox"
:
trans_poly_to_bbox
(
res
[
0
]),
"poly"
:
res
[
0
],
})
segment_offset_id
=
[]
words_list
=
[]
bbox_list
=
[]
input_ids_list
=
[]
token_type_ids_list
=
[]
entities
=
[]
for
info
in
ocr_info
:
# x1, y1, x2, y2
bbox
=
copy
.
deepcopy
(
info
[
"bbox"
])
bbox
[
0
]
=
int
(
bbox
[
0
]
*
1000.0
/
width
)
bbox
[
2
]
=
int
(
bbox
[
2
]
*
1000.0
/
width
)
bbox
[
1
]
=
int
(
bbox
[
1
]
*
1000.0
/
height
)
bbox
[
3
]
=
int
(
bbox
[
3
]
*
1000.0
/
height
)
text
=
info
[
"text"
]
encode_res
=
self
.
tokenizer
.
encode
(
text
,
pad_to_max_seq_len
=
False
,
return_attention_mask
=
True
)
if
not
self
.
add_special_ids
:
# TODO: use tok.all_special_ids to remove
encode_res
[
"input_ids"
]
=
encode_res
[
"input_ids"
][
1
:
-
1
]
encode_res
[
"token_type_ids"
]
=
encode_res
[
"token_type_ids"
][
1
:
-
1
]
encode_res
[
"attention_mask"
]
=
encode_res
[
"attention_mask"
][
1
:
-
1
]
# for re
entities
.
append
({
"start"
:
len
(
input_ids_list
),
"end"
:
len
(
input_ids_list
)
+
len
(
encode_res
[
"input_ids"
]),
"label"
:
"O"
,
})
input_ids_list
.
extend
(
encode_res
[
"input_ids"
])
token_type_ids_list
.
extend
(
encode_res
[
"token_type_ids"
])
bbox_list
.
extend
([
bbox
]
*
len
(
encode_res
[
"input_ids"
]))
words_list
.
append
(
text
)
segment_offset_id
.
append
(
len
(
input_ids_list
))
encoded_inputs
=
{
"input_ids"
:
input_ids_list
,
"token_type_ids"
:
token_type_ids_list
,
"bbox"
:
bbox_list
,
"attention_mask"
:
[
1
]
*
len
(
input_ids_list
),
"entities"
:
entities
,
'labels'
:
None
,
'segment_offset_id'
:
segment_offset_id
,
'ocr_info'
:
ocr_info
}
data
.
update
(
encoded_inputs
)
return
data
ppocr/data/imaug/operators.py
View file @
a323fce6
...
...
@@ -170,17 +170,19 @@ class Resize(object):
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
if
'polys'
in
data
:
text_polys
=
data
[
'polys'
]
img_resize
,
[
ratio_h
,
ratio_w
]
=
self
.
resize_image
(
img
)
if
'polys'
in
data
:
new_boxes
=
[]
for
box
in
text_polys
:
new_box
=
[]
for
cord
in
box
:
new_box
.
append
([
cord
[
0
]
*
ratio_w
,
cord
[
1
]
*
ratio_h
])
new_boxes
.
append
(
new_box
)
data
[
'image'
]
=
img_resize
data
[
'polys'
]
=
np
.
array
(
new_boxes
,
dtype
=
np
.
float32
)
data
[
'image'
]
=
img_resize
return
data
...
...
ppocr/data/imaug/vqa/__init__.py
0 → 100644
View file @
a323fce6
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.token
import
VQATokenPad
,
VQASerTokenChunk
,
VQAReTokenChunk
,
VQAReTokenRelation
__all__
=
[
'VQATokenPad'
,
'VQASerTokenChunk'
,
'VQAReTokenChunk'
,
'VQAReTokenRelation'
]
ppocr/data/imaug/vqa/token/__init__.py
0 → 100644
View file @
a323fce6
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.vqa_token_chunk
import
VQASerTokenChunk
,
VQAReTokenChunk
from
.vqa_token_pad
import
VQATokenPad
from
.vqa_token_relation
import
VQAReTokenRelation
ppocr/data/imaug/vqa/token/vqa_token_chunk.py
0 → 100644
View file @
a323fce6
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class
VQASerTokenChunk
(
object
):
def
__init__
(
self
,
max_seq_len
=
512
,
infer_mode
=
False
,
**
kwargs
):
self
.
max_seq_len
=
max_seq_len
self
.
infer_mode
=
infer_mode
def
__call__
(
self
,
data
):
encoded_inputs_all
=
[]
seq_len
=
len
(
data
[
'input_ids'
])
for
index
in
range
(
0
,
seq_len
,
self
.
max_seq_len
):
chunk_beg
=
index
chunk_end
=
min
(
index
+
self
.
max_seq_len
,
seq_len
)
encoded_inputs_example
=
{}
for
key
in
data
:
if
key
in
[
'label'
,
'input_ids'
,
'labels'
,
'token_type_ids'
,
'bbox'
,
'attention_mask'
]:
if
self
.
infer_mode
and
key
==
'labels'
:
encoded_inputs_example
[
key
]
=
data
[
key
]
else
:
encoded_inputs_example
[
key
]
=
data
[
key
][
chunk_beg
:
chunk_end
]
else
:
encoded_inputs_example
[
key
]
=
data
[
key
]
encoded_inputs_all
.
append
(
encoded_inputs_example
)
return
encoded_inputs_all
[
0
]
class
VQAReTokenChunk
(
object
):
def
__init__
(
self
,
max_seq_len
=
512
,
entities_labels
=
None
,
infer_mode
=
False
,
**
kwargs
):
self
.
max_seq_len
=
max_seq_len
self
.
entities_labels
=
{
'HEADER'
:
0
,
'QUESTION'
:
1
,
'ANSWER'
:
2
}
if
entities_labels
is
None
else
entities_labels
self
.
infer_mode
=
infer_mode
def
__call__
(
self
,
data
):
# prepare data
entities
=
data
.
pop
(
'entities'
)
relations
=
data
.
pop
(
'relations'
)
encoded_inputs_all
=
[]
for
index
in
range
(
0
,
len
(
data
[
"input_ids"
]),
self
.
max_seq_len
):
item
=
{}
for
key
in
data
:
if
key
in
[
'label'
,
'input_ids'
,
'labels'
,
'token_type_ids'
,
'bbox'
,
'attention_mask'
]:
if
self
.
infer_mode
and
key
==
'labels'
:
item
[
key
]
=
data
[
key
]
else
:
item
[
key
]
=
data
[
key
][
index
:
index
+
self
.
max_seq_len
]
else
:
item
[
key
]
=
data
[
key
]
# select entity in current chunk
entities_in_this_span
=
[]
global_to_local_map
=
{}
#
for
entity_id
,
entity
in
enumerate
(
entities
):
if
(
index
<=
entity
[
"start"
]
<
index
+
self
.
max_seq_len
and
index
<=
entity
[
"end"
]
<
index
+
self
.
max_seq_len
):
entity
[
"start"
]
=
entity
[
"start"
]
-
index
entity
[
"end"
]
=
entity
[
"end"
]
-
index
global_to_local_map
[
entity_id
]
=
len
(
entities_in_this_span
)
entities_in_this_span
.
append
(
entity
)
# select relations in current chunk
relations_in_this_span
=
[]
for
relation
in
relations
:
if
(
index
<=
relation
[
"start_index"
]
<
index
+
self
.
max_seq_len
and
index
<=
relation
[
"end_index"
]
<
index
+
self
.
max_seq_len
):
relations_in_this_span
.
append
({
"head"
:
global_to_local_map
[
relation
[
"head"
]],
"tail"
:
global_to_local_map
[
relation
[
"tail"
]],
"start_index"
:
relation
[
"start_index"
]
-
index
,
"end_index"
:
relation
[
"end_index"
]
-
index
,
})
item
.
update
({
"entities"
:
self
.
reformat
(
entities_in_this_span
),
"relations"
:
self
.
reformat
(
relations_in_this_span
),
})
item
[
'entities'
][
'label'
]
=
[
self
.
entities_labels
[
x
]
for
x
in
item
[
'entities'
][
'label'
]
]
encoded_inputs_all
.
append
(
item
)
return
encoded_inputs_all
[
0
]
def
reformat
(
self
,
data
):
new_data
=
{}
for
item
in
data
:
for
k
,
v
in
item
.
items
():
if
k
not
in
new_data
:
new_data
[
k
]
=
[]
new_data
[
k
].
append
(
v
)
return
new_data
ppocr/data/imaug/vqa/token/vqa_token_pad.py
0 → 100644
View file @
a323fce6
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
numpy
as
np
class
VQATokenPad
(
object
):
def
__init__
(
self
,
max_seq_len
=
512
,
pad_to_max_seq_len
=
True
,
return_attention_mask
=
True
,
return_token_type_ids
=
True
,
truncation_strategy
=
"longest_first"
,
return_overflowing_tokens
=
False
,
return_special_tokens_mask
=
False
,
infer_mode
=
False
,
**
kwargs
):
self
.
max_seq_len
=
max_seq_len
self
.
pad_to_max_seq_len
=
max_seq_len
self
.
return_attention_mask
=
return_attention_mask
self
.
return_token_type_ids
=
return_token_type_ids
self
.
truncation_strategy
=
truncation_strategy
self
.
return_overflowing_tokens
=
return_overflowing_tokens
self
.
return_special_tokens_mask
=
return_special_tokens_mask
self
.
pad_token_label_id
=
paddle
.
nn
.
CrossEntropyLoss
().
ignore_index
self
.
infer_mode
=
infer_mode
def
__call__
(
self
,
data
):
needs_to_be_padded
=
self
.
pad_to_max_seq_len
and
len
(
data
[
"input_ids"
])
<
self
.
max_seq_len
if
needs_to_be_padded
:
if
'tokenizer_params'
in
data
:
tokenizer_params
=
data
.
pop
(
'tokenizer_params'
)
else
:
tokenizer_params
=
dict
(
padding_side
=
'right'
,
pad_token_type_id
=
0
,
pad_token_id
=
1
)
difference
=
self
.
max_seq_len
-
len
(
data
[
"input_ids"
])
if
tokenizer_params
[
'padding_side'
]
==
'right'
:
if
self
.
return_attention_mask
:
data
[
"attention_mask"
]
=
[
1
]
*
len
(
data
[
"input_ids"
])
+
[
0
]
*
difference
if
self
.
return_token_type_ids
:
data
[
"token_type_ids"
]
=
(
data
[
"token_type_ids"
]
+
[
tokenizer_params
[
'pad_token_type_id'
]]
*
difference
)
if
self
.
return_special_tokens_mask
:
data
[
"special_tokens_mask"
]
=
data
[
"special_tokens_mask"
]
+
[
1
]
*
difference
data
[
"input_ids"
]
=
data
[
"input_ids"
]
+
[
tokenizer_params
[
'pad_token_id'
]
]
*
difference
if
not
self
.
infer_mode
:
data
[
"labels"
]
=
data
[
"labels"
]
+
[
self
.
pad_token_label_id
]
*
difference
data
[
"bbox"
]
=
data
[
"bbox"
]
+
[[
0
,
0
,
0
,
0
]]
*
difference
elif
tokenizer_params
[
'padding_side'
]
==
'left'
:
if
self
.
return_attention_mask
:
data
[
"attention_mask"
]
=
[
0
]
*
difference
+
[
1
]
*
len
(
data
[
"input_ids"
])
if
self
.
return_token_type_ids
:
data
[
"token_type_ids"
]
=
(
[
tokenizer_params
[
'pad_token_type_id'
]]
*
difference
+
data
[
"token_type_ids"
])
if
self
.
return_special_tokens_mask
:
data
[
"special_tokens_mask"
]
=
[
1
]
*
difference
+
data
[
"special_tokens_mask"
]
data
[
"input_ids"
]
=
[
tokenizer_params
[
'pad_token_id'
]
]
*
difference
+
data
[
"input_ids"
]
if
not
self
.
infer_mode
:
data
[
"labels"
]
=
[
self
.
pad_token_label_id
]
*
difference
+
data
[
"labels"
]
data
[
"bbox"
]
=
[[
0
,
0
,
0
,
0
]]
*
difference
+
data
[
"bbox"
]
else
:
if
self
.
return_attention_mask
:
data
[
"attention_mask"
]
=
[
1
]
*
len
(
data
[
"input_ids"
])
for
key
in
data
:
if
key
in
[
'input_ids'
,
'labels'
,
'token_type_ids'
,
'bbox'
,
'attention_mask'
]:
if
self
.
infer_mode
and
key
==
'labels'
:
continue
length
=
min
(
len
(
data
[
key
]),
self
.
max_seq_len
)
data
[
key
]
=
np
.
array
(
data
[
key
][:
length
],
dtype
=
'int64'
)
return
data
ppocr/data/imaug/vqa/token/vqa_token_relation.py
0 → 100644
View file @
a323fce6
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class
VQAReTokenRelation
(
object
):
def
__init__
(
self
,
**
kwargs
):
pass
def
__call__
(
self
,
data
):
"""
build relations
"""
entities
=
data
[
'entities'
]
relations
=
data
[
'relations'
]
id2label
=
data
.
pop
(
'id2label'
)
empty_entity
=
data
.
pop
(
'empty_entity'
)
entity_id_to_index_map
=
data
.
pop
(
'entity_id_to_index_map'
)
relations
=
list
(
set
(
relations
))
relations
=
[
rel
for
rel
in
relations
if
rel
[
0
]
not
in
empty_entity
and
rel
[
1
]
not
in
empty_entity
]
kv_relations
=
[]
for
rel
in
relations
:
pair
=
[
id2label
[
rel
[
0
]],
id2label
[
rel
[
1
]]]
if
pair
==
[
"question"
,
"answer"
]:
kv_relations
.
append
({
"head"
:
entity_id_to_index_map
[
rel
[
0
]],
"tail"
:
entity_id_to_index_map
[
rel
[
1
]]
})
elif
pair
==
[
"answer"
,
"question"
]:
kv_relations
.
append
({
"head"
:
entity_id_to_index_map
[
rel
[
1
]],
"tail"
:
entity_id_to_index_map
[
rel
[
0
]]
})
else
:
continue
relations
=
sorted
(
[{
"head"
:
rel
[
"head"
],
"tail"
:
rel
[
"tail"
],
"start_index"
:
self
.
get_relation_span
(
rel
,
entities
)[
0
],
"end_index"
:
self
.
get_relation_span
(
rel
,
entities
)[
1
],
}
for
rel
in
kv_relations
],
key
=
lambda
x
:
x
[
"head"
],
)
data
[
'relations'
]
=
relations
return
data
def
get_relation_span
(
self
,
rel
,
entities
):
bound
=
[]
for
entity_index
in
[
rel
[
"head"
],
rel
[
"tail"
]]:
bound
.
append
(
entities
[
entity_index
][
"start"
])
bound
.
append
(
entities
[
entity_index
][
"end"
])
return
min
(
bound
),
max
(
bound
)
ppocr/data/simple_dataset.py
View file @
a323fce6
...
...
@@ -41,7 +41,6 @@ class SimpleDataSet(Dataset):
)
==
data_source_num
,
"The length of ratio_list should be the same as the file_list."
self
.
data_dir
=
dataset_config
[
'data_dir'
]
self
.
do_shuffle
=
loader_config
[
'shuffle'
]
self
.
seed
=
seed
logger
.
info
(
"Initialize indexs of datasets:%s"
%
label_file_list
)
self
.
data_lines
=
self
.
get_image_info_list
(
label_file_list
,
ratio_list
)
...
...
@@ -50,6 +49,8 @@ class SimpleDataSet(Dataset):
self
.
shuffle_data_random
()
self
.
ops
=
create_operators
(
dataset_config
[
'transforms'
],
global_config
)
self
.
need_reset
=
True
in
[
x
<
1
for
x
in
ratio_list
]
def
get_image_info_list
(
self
,
file_list
,
ratio_list
):
if
isinstance
(
file_list
,
str
):
file_list
=
[
file_list
]
...
...
@@ -95,7 +96,7 @@ class SimpleDataSet(Dataset):
data
[
'image'
]
=
img
data
=
transform
(
data
,
load_data_ops
)
if
data
is
None
or
data
[
'polys'
].
shape
[
1
]
!=
4
:
if
data
is
None
or
data
[
'polys'
].
shape
[
1
]
!=
4
:
continue
ext_data
.
append
(
data
)
return
ext_data
...
...
@@ -121,7 +122,7 @@ class SimpleDataSet(Dataset):
self
.
logger
.
error
(
"When parsing line {}, error happened with msg: {}"
.
format
(
data_line
,
traceback
.
format_exc
()))
outs
=
None
#
outs = None
if
outs
is
None
:
# during evaluation, we should fix the idx to get same results for many times of evaluation.
rnd_idx
=
np
.
random
.
randint
(
self
.
__len__
(
...
...
ppocr/losses/__init__.py
View file @
a323fce6
...
...
@@ -16,6 +16,9 @@ import copy
import
paddle
import
paddle.nn
as
nn
# basic_loss
from
.basic_loss
import
LossFromOutput
# det loss
from
.det_db_loss
import
DBLoss
from
.det_east_loss
import
EASTLoss
...
...
@@ -46,12 +49,16 @@ from .combined_loss import CombinedLoss
# table loss
from
.table_att_loss
import
TableAttentionLoss
# vqa token loss
from
.vqa_token_layoutlm_loss
import
VQASerTokenLayoutLMLoss
def
build_loss
(
config
):
support_dict
=
[
'DBLoss'
,
'PSELoss'
,
'EASTLoss'
,
'SASTLoss'
,
'CTCLoss'
,
'ClsLoss'
,
'AttentionLoss'
,
'SRNLoss'
,
'PGLoss'
,
'CombinedLoss'
,
'NRTRLoss'
,
'TableAttentionLoss'
,
'SARLoss'
,
'AsterLoss'
,
'SDMGRLoss'
'TableAttentionLoss'
,
'SARLoss'
,
'AsterLoss'
,
'SDMGRLoss'
,
'VQASerTokenLayoutLMLoss'
,
'LossFromOutput'
]
config
=
copy
.
deepcopy
(
config
)
module_name
=
config
.
pop
(
'name'
)
...
...
ppocr/losses/basic_loss.py
View file @
a323fce6
...
...
@@ -133,3 +133,18 @@ class DistanceLoss(nn.Layer):
def
forward
(
self
,
x
,
y
):
return
self
.
loss_func
(
x
,
y
)
class LossFromOutput(nn.Layer):
    """Loss wrapper that extracts an already-computed loss from the model
    output dict and optionally reduces it.

    Args:
        key: key under which the model stores its loss in ``predicts``.
        reduction: 'none' (default), 'mean' or 'sum'.
    """

    def __init__(self, key='loss', reduction='none'):
        super().__init__()
        self.key = key
        self.reduction = reduction

    def forward(self, predicts, batch):
        # The model has already computed the loss; just pick it out.
        value = predicts[self.key]
        reduction = self.reduction
        if reduction == 'mean':
            value = paddle.mean(value)
        elif reduction == 'sum':
            value = paddle.sum(value)
        return {'loss': value}
pp
structure/vqa/losse
s.py
→
pp
ocr/losses/vqa_token_layoutlm_los
s.py
100644 → 100755
View file @
a323fce6
#
C
opyright (c) 20
2
1 PaddlePaddle Authors. All Rights Reserve
d
.
#
c
opyright (c) 201
9
PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -12,24 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
paddle
import
nn
class VQASerTokenLayoutLMLoss(nn.Layer):
    """Token-level cross-entropy loss for LayoutLM-style SER.

    Flattens the per-token logits and labels and, when an attention mask
    is provided in the batch, scores only the unmasked (real) tokens.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.loss_class = nn.CrossEntropyLoss()
        self.num_classes = num_classes
        # Expose CrossEntropyLoss's ignore_index for callers.
        self.ignore_index = self.loss_class.ignore_index

    def forward(self, predicts, batch):
        # Batch layout (from the dataloader): labels at index 1,
        # attention mask at index 4.
        labels = batch[1]
        attention_mask = batch[4]
        flat_logits = predicts.reshape([-1, self.num_classes])
        flat_labels = labels.reshape([-1, ])
        if attention_mask is None:
            loss = self.loss_class(flat_logits, flat_labels)
        else:
            # Keep only positions where the mask is 1 (non-padding tokens).
            keep = attention_mask.reshape([-1, ]) == 1
            loss = self.loss_class(flat_logits[keep], flat_labels[keep])
        return {'loss': loss}
ppocr/metrics/__init__.py
View file @
a323fce6
...
...
@@ -28,12 +28,15 @@ from .e2e_metric import E2EMetric
from
.distillation_metric
import
DistillationMetric
from
.table_metric
import
TableMetric
from
.kie_metric
import
KIEMetric
from
.vqa_token_ser_metric
import
VQASerTokenMetric
from
.vqa_token_re_metric
import
VQAReTokenMetric
def
build_metric
(
config
):
support_dict
=
[
"DetMetric"
,
"RecMetric"
,
"ClsMetric"
,
"E2EMetric"
,
"DistillationMetric"
,
"TableMetric"
,
'KIEMetric'
"DistillationMetric"
,
"TableMetric"
,
'KIEMetric'
,
'VQASerTokenMetric'
,
'VQAReTokenMetric'
]
config
=
copy
.
deepcopy
(
config
)
...
...
ppocr/metrics/vqa_token_re_metric.py
0 → 100644
View file @
a323fce6
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
paddle
from
seqeval.metrics
import
f1_score
,
precision_score
,
recall_score
# Public API of this module; previously exported the unrelated name
# 'KIEMetric' (copy-paste from kie_metric.py), which is not defined here
# and would break `from ... import *`.
__all__ = ['VQAReTokenMetric']
class VQAReTokenMetric(object):
    """Relation-extraction (RE) metric for the VQA token task.

    Accumulates predicted and ground-truth relations batch by batch and
    reports micro precision / recall / f1 ("hmean") computed over entity
    span boundaries.
    """

    def __init__(self, main_indicator='hmean', **kwargs):
        # Name of the indicator used for best-model selection.
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, preds, batch, **kwargs):
        """Accumulate one batch.

        Args:
            preds: tuple ``(pred_relations, relations, entities)`` where
                each element is a per-sample list.
            batch: unused, kept for metric-interface compatibility.
        """
        pred_relations, relations, entities = preds
        self.pred_relations_list.extend(pred_relations)
        self.relations_list.extend(relations)
        self.entities_list.extend(entities)

    def get_metric(self):
        """Score accumulated predictions, reset the accumulators, and
        return {"precision", "recall", "hmean"}."""
        gt_relations = []
        for b in range(len(self.relations_list)):
            rel_sent = []
            relations = self.relations_list[b]
            entities = self.entities_list[b]
            for head, tail in zip(relations["head"], relations["tail"]):
                rel = {
                    "head_id": head,
                    "head": (entities["start"][head], entities["end"][head]),
                    "head_type": entities["label"][head],
                    "tail_id": tail,
                    "tail": (entities["start"][tail], entities["end"][tail]),
                    "tail_type": entities["label"][tail],
                    # This task has a single positive relation type.
                    "type": 1,
                }
                rel_sent.append(rel)
            gt_relations.append(rel_sent)
        re_metrics = self.re_score(
            self.pred_relations_list, gt_relations, mode="boundaries")
        metrics = {
            "precision": re_metrics["ALL"]["p"],
            "recall": re_metrics["ALL"]["r"],
            "hmean": re_metrics["ALL"]["f1"],
        }
        self.reset()
        return metrics

    def reset(self):
        """Clear all accumulated predictions and ground truths."""
        self.pred_relations_list = []
        self.relations_list = []
        self.entities_list = []

    def re_score(self, pred_relations, gt_relations, mode="strict"):
        """Evaluate RE predictions.

        Args:
            pred_relations (list): list of lists of predicted relations
                (several relations per sentence).
            gt_relations (list): list of lists of ground-truth relations.
                Each relation is a dict:
                    {"head": (start_idx, end_idx),   # end exclusive
                     "tail": (start_idx, end_idx),
                     "head_type": ent_type,
                     "tail_type": ent_type,
                     "type": rel_type}
            mode (str): "strict" (spans AND entity types must match) or
                "boundaries" (spans only).

        Returns:
            dict mapping each relation type (and "ALL") to its tp/fp/fn
            counts and p/r/f1 scores; "ALL" also carries micro and macro
            aggregates.
        """
        assert mode in ["strict", "boundaries"]

        # Single positive relation type (1); 0 means "no relation".
        # (Replaces the original convoluted
        # `[v for v in [0, 1] if not v == 0]`; also drops the unused
        # n_sents / n_rels / n_found counters.)
        relation_types = [1]
        scores = {
            rel: {"tp": 0, "fp": 0, "fn": 0}
            for rel in relation_types + ["ALL"]
        }

        # Count TP, FP and FN per type, sentence by sentence.
        for pred_sent, gt_sent in zip(pred_relations, gt_relations):
            for rel_type in relation_types:
                if mode == "strict":
                    # Strict mode also requires the argument types to match.
                    pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
                                  rel["tail_type"])
                                 for rel in pred_sent
                                 if rel["type"] == rel_type}
                    gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
                                rel["tail_type"])
                               for rel in gt_sent if rel["type"] == rel_type}
                else:
                    # "boundaries": only the argument spans must match.
                    pred_rels = {(rel["head"], rel["tail"])
                                 for rel in pred_sent
                                 if rel["type"] == rel_type}
                    gt_rels = {(rel["head"], rel["tail"])
                               for rel in gt_sent if rel["type"] == rel_type}

                scores[rel_type]["tp"] += len(pred_rels & gt_rels)
                scores[rel_type]["fp"] += len(pred_rels - gt_rels)
                scores[rel_type]["fn"] += len(gt_rels - pred_rels)

        # Per-type precision / recall / f1 (guarded against zero division).
        for rel_type in scores.keys():
            tp = scores[rel_type]["tp"]
            if tp:
                scores[rel_type]["p"] = tp / (scores[rel_type]["fp"] + tp)
                scores[rel_type]["r"] = tp / (scores[rel_type]["fn"] + tp)
            else:
                scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
            denom = scores[rel_type]["p"] + scores[rel_type]["r"]
            scores[rel_type]["f1"] = (
                2 * scores[rel_type]["p"] * scores[rel_type]["r"] / denom
                if denom != 0 else 0)

        # Micro scores over all relation types.
        tp = sum(scores[rel_type]["tp"] for rel_type in relation_types)
        fp = sum(scores[rel_type]["fp"] for rel_type in relation_types)
        fn = sum(scores[rel_type]["fn"] for rel_type in relation_types)
        if tp:
            precision = tp / (tp + fp)
            recall = tp / (tp + fn)
            f1 = 2 * precision * recall / (precision + recall)
        else:
            precision, recall, f1 = 0, 0, 0
        scores["ALL"]["p"] = precision
        scores["ALL"]["r"] = recall
        scores["ALL"]["f1"] = f1
        scores["ALL"]["tp"] = tp
        scores["ALL"]["fp"] = fp
        scores["ALL"]["fn"] = fn

        # Macro scores: unweighted mean over relation types.
        scores["ALL"]["Macro_f1"] = np.mean(
            [scores[t]["f1"] for t in relation_types])
        scores["ALL"]["Macro_p"] = np.mean(
            [scores[t]["p"] for t in relation_types])
        scores["ALL"]["Macro_r"] = np.mean(
            [scores[t]["r"] for t in relation_types])

        return scores
ppocr/metrics/vqa_token_ser_metric.py
0 → 100644
View file @
a323fce6
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
paddle
from
seqeval.metrics
import
f1_score
,
precision_score
,
recall_score
# Public API of this module; previously exported the unrelated name
# 'KIEMetric' (copy-paste from kie_metric.py), which is not defined here
# and would break `from ... import *`.
__all__ = ['VQASerTokenMetric']
class VQASerTokenMetric(object):
    """Token-classification (SER) metric.

    Accumulates flat prediction / ground-truth tag sequences across
    batches and reports seqeval precision, recall and f1 ("hmean").
    """

    def __init__(self, main_indicator='hmean', **kwargs):
        # Name of the indicator used for best-model selection.
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, preds, batch, **kwargs):
        # preds is a (predicted_tags, ground_truth_tags) pair.
        pred_tags, gt_tags = preds
        self.pred_list.extend(pred_tags)
        self.gt_list.extend(gt_tags)

    def get_metric(self):
        """Score the accumulated tags, reset, and return the metrics."""
        metrics = {
            "precision": precision_score(self.gt_list, self.pred_list),
            "recall": recall_score(self.gt_list, self.pred_list),
            "hmean": f1_score(self.gt_list, self.pred_list),
        }
        self.reset()
        return metrics

    def reset(self):
        """Clear accumulated predictions and ground truths."""
        self.pred_list = []
        self.gt_list = []
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment