Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
a323fce6
Commit
a323fce6
authored
Jan 05, 2022
by
WenmuZhou
Browse files
vqa code integrated into ppocr training system
parent
1ded2ac4
Changes
54
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
394 additions
and
1870 deletions
+394
-1870
ppstructure/vqa/infer_ser_e2e.py
ppstructure/vqa/infer_ser_e2e.py
+0
-156
ppstructure/vqa/infer_ser_re_e2e.py
ppstructure/vqa/infer_ser_re_e2e.py
+0
-135
ppstructure/vqa/metric.py
ppstructure/vqa/metric.py
+0
-175
ppstructure/vqa/requirements.txt
ppstructure/vqa/requirements.txt
+2
-1
ppstructure/vqa/train_re.py
ppstructure/vqa/train_re.py
+0
-229
ppstructure/vqa/train_ser.py
ppstructure/vqa/train_ser.py
+0
-248
ppstructure/vqa/vqa_utils.py
ppstructure/vqa/vqa_utils.py
+0
-400
ppstructure/vqa/xfun.py
ppstructure/vqa/xfun.py
+0
-464
requirements.txt
requirements.txt
+0
-1
tools/eval.py
tools/eval.py
+2
-1
tools/infer_vqa_token_ser.py
tools/infer_vqa_token_ser.py
+135
-0
tools/infer_vqa_token_ser_re.py
tools/infer_vqa_token_ser_re.py
+199
-0
tools/program.py
tools/program.py
+54
-59
tools/train.py
tools/train.py
+2
-1
No files found.
ppstructure/vqa/infer_ser_e2e.py
deleted
100644 → 0
View file @
1ded2ac4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
import
json
import
cv2
import
numpy
as
np
from
copy
import
deepcopy
from
PIL
import
Image
import
paddle
from
paddlenlp.transformers
import
LayoutXLMModel
,
LayoutXLMTokenizer
,
LayoutXLMForTokenClassification
from
paddlenlp.transformers
import
LayoutLMModel
,
LayoutLMTokenizer
,
LayoutLMForTokenClassification
# relative reference
from
vqa_utils
import
parse_args
,
get_image_file_list
,
draw_ser_results
,
get_bio_label_maps
from
vqa_utils
import
pad_sentences
,
split_page
,
preprocess
,
postprocess
,
merge_preds_list_with_ocr_info
# Lookup table keyed by model-type name. Each value is a 3-tuple of
# (tokenizer class, base model class, token-classification model class).
# SerPredictor.__init__ indexes it with `args.ser_model_type`.
MODELS = {
    'LayoutXLM':
    (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
    'LayoutLM':
    (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
def trans_poly_to_bbox(poly):
    """Return the axis-aligned bounding box of a polygon.

    Args:
        poly: iterable of points; index 0 of each point is read as x and
            index 1 as y.

    Returns:
        list: ``[x_min, y_min, x_max, y_max]`` computed with numpy.
    """
    xs = [point[0] for point in poly]
    ys = [point[1] for point in poly]
    return [np.min(xs), np.min(ys), np.max(xs), np.max(ys)]
def parse_ocr_info_for_ser(ocr_result):
    """Convert raw OCR results into the dict records used by the SER step.

    Args:
        ocr_result: iterable of OCR items. For each item, ``item[0]`` is used
            as the polygon and ``item[1][0]`` as the text.
            # presumably the PaddleOCR (box, (text, score)) layout; confirm

    Returns:
        list[dict]: one dict per item with keys ``"text"``, ``"bbox"``
        (from ``trans_poly_to_bbox``) and ``"poly"``.
    """
    return [{
        "text": item[1][0],
        "bbox": trans_poly_to_bbox(item[0]),
        "poly": item[0],
    } for item in ocr_result]
class SerPredictor(object):
    """End-to-end SER predictor: OCR an image, then classify its tokens.

    The tokenizer and token-classification model are picked from ``MODELS``
    by ``args.ser_model_type`` and loaded from ``args.model_name_or_path``.
    """

    def __init__(self, args):
        """Load tokenizer, model, OCR engine and label maps.

        Args:
            args: object providing ``max_seq_length``, ``ser_model_type``,
                ``model_name_or_path``, ``rec_model_dir``, ``det_model_dir``
                and ``label_map_path``.
        """
        self.args = args
        self.max_seq_length = args.max_seq_length

        # Tokenizer and classification model (the base model class is unused here).
        tokenizer_class, _, model_class = MODELS[args.ser_model_type]
        pretrained = args.model_name_or_path
        self.tokenizer = tokenizer_class.from_pretrained(pretrained)
        self.model = model_class.from_pretrained(pretrained)
        self.model.eval()

        # OCR engine; imported here so paddleocr is only needed once a
        # predictor is actually built.
        from paddleocr import PaddleOCR
        self.ocr_engine = PaddleOCR(
            rec_model_dir=args.rec_model_dir,
            det_model_dir=args.det_model_dir,
            use_angle_cls=False,
            show_log=False)

        # Label dictionaries. For drawing, every "I-xxx" label is mapped to
        # the id of its "B-xxx" counterpart; other labels keep their own id.
        label2id_map, self.id2label_map = get_bio_label_maps(
            args.label_map_path)
        self.label2id_map_for_draw = {
            name: (label2id_map["B" + name[1:]]
                   if name.startswith("I-") else label2id_map[name])
            for name in label2id_map
        }

    def __call__(self, img):
        """Run OCR followed by SER on one image.

        Args:
            img: image passed to the OCR engine and to ``preprocess``.

        Returns:
            tuple: ``(ocr_info, inputs)`` where ``ocr_info`` is the result of
            ``merge_preds_list_with_ocr_info`` and ``inputs`` is the dict
            produced by ``preprocess``.
        """
        detections = self.ocr_engine.ocr(img, cls=False)
        ocr_info = parse_ocr_info_for_ser(detections)

        inputs = preprocess(
            tokenizer=self.tokenizer,
            ori_img=img,
            ocr_info=ocr_info,
            max_seq_len=self.max_seq_length)

        model_type = self.args.ser_model_type
        if model_type == 'LayoutLM':
            # LayoutLM is called without the image input.
            preds = self.model(
                input_ids=inputs["input_ids"],
                bbox=inputs["bbox"],
                token_type_ids=inputs["token_type_ids"],
                attention_mask=inputs["attention_mask"])
        elif model_type == 'LayoutXLM':
            # Only the first element of the LayoutXLM output is used.
            preds = self.model(
                input_ids=inputs["input_ids"],
                bbox=inputs["bbox"],
                image=inputs["image"],
                token_type_ids=inputs["token_type_ids"],
                attention_mask=inputs["attention_mask"])[0]

        pred_labels = postprocess(inputs["attention_mask"], preds,
                                  self.id2label_map)
        merged_info = merge_preds_list_with_ocr_info(
            ocr_info, inputs["segment_offset_id"], pred_labels,
            self.label2id_map_for_draw)
        return merged_info, inputs
if __name__ == "__main__":
    # Script entry point: run SER on every input image, append one line per
    # image to infer_results.txt and save a visualisation next to it.
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Collect the images to process.
    infer_imgs = get_image_file_list(args.infer_imgs)
    total = len(infer_imgs)

    # Build the predictor once, then loop over the images.
    ser_engine = SerPredictor(args)
    results_path = os.path.join(args.output_dir, "infer_results.txt")
    with open(results_path, "w", encoding='utf-8') as fout:
        for idx, img_path in enumerate(infer_imgs):
            stem = os.path.splitext(os.path.basename(img_path))[0]
            save_img_path = os.path.join(args.output_dir, stem + "_ser.jpg")
            print("process: [{}/{}], save result to {}".format(
                idx, total, save_img_path))

            img = cv2.imread(img_path)
            result, _ = ser_engine(img)

            # One tab-separated record per image: path, then JSON payload.
            payload = json.dumps({"ser_resule": result, }, ensure_ascii=False)
            fout.write(img_path + "\t" + payload + "\n")

            cv2.imwrite(save_img_path, draw_ser_results(img, result))
ppstructure/vqa/infer_ser_re_e2e.py
deleted
100644 → 0
View file @
1ded2ac4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
json
import
cv2
import
numpy
as
np
from
copy
import
deepcopy
from
PIL
import
Image
import
paddle
from
paddlenlp.transformers
import
LayoutXLMModel
,
LayoutXLMTokenizer
,
LayoutXLMForRelationExtraction
# relative reference
from
vqa_utils
import
parse_args
,
get_image_file_list
,
draw_re_results
from
infer_ser_e2e
import
SerPredictor
def make_input(ser_input, ser_result, max_seq_len=512):
    """Turn SER inputs and predictions into the input of the RE model.

    ``ser_input`` is modified in place: its ``'entities'`` and
    ``'relations'`` entries are replaced and ``'segment_offset_id'`` is
    removed.

    Args:
        ser_input (dict): SER model inputs. ``ser_input['entities'][0]`` is
            read as the entity list and ``ser_input['input_ids'].shape[0]``
            as the batch size.
        ser_result (list): one record per entity, each with a ``'pred'``
            label ('O', 'HEADER', 'QUESTION' or 'ANSWER').
        max_seq_len (int): not used by this function.

    Returns:
        tuple: ``(ser_input, entity_idx_dict)`` where ``entity_idx_dict``
        maps the index of each kept (non-'O') entity to its index in
        ``ser_result``.
    """
    label_ids = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}

    raw_entities = ser_input['entities'][0]
    assert len(raw_entities) == len(ser_result)

    # Keep every entity whose prediction is not 'O', remembering where it
    # came from in ser_result.
    kept = [(pos, res, ent)
            for pos, (res, ent) in enumerate(zip(ser_result, raw_entities))
            if res['pred'] != 'O']
    entity_idx_dict = {new_pos: pos for new_pos, (pos, _, _) in enumerate(kept)}
    entities = dict(
        start=[ent['start'] for _, _, ent in kept],
        end=[ent['end'] for _, _, ent in kept],
        label=[label_ids[res['pred']] for _, res, _ in kept])

    # Candidate relations: every (QUESTION, ANSWER) pair of kept entities.
    labels = entities["label"]
    pairs = [(q, a)
             for q, q_label in enumerate(labels)
             for a, a_label in enumerate(labels)
             if q_label == 1 and a_label == 2]
    relations = dict(
        head=[q for q, _ in pairs], tail=[a for _, a in pairs])

    # The same entities/relations objects are shared by every batch row.
    batch_size = ser_input["input_ids"].shape[0]
    ser_input['entities'] = [entities] * batch_size
    ser_input['relations'] = [relations] * batch_size

    ser_input.pop('segment_offset_id')
    return ser_input, entity_idx_dict
class SerReSystem(object):
    """SER followed by relation extraction (RE) on a single image.

    Wraps a ``SerPredictor`` and a ``LayoutXLMForRelationExtraction`` model
    loaded from ``args.re_model_name_or_path``.
    """

    def __init__(self, args):
        """Build the SER engine and load the RE tokenizer and model.

        Args:
            args: object passed on to ``SerPredictor`` and providing
                ``re_model_name_or_path``.
        """
        self.ser_engine = SerPredictor(args)
        self.tokenizer = LayoutXLMTokenizer.from_pretrained(
            args.re_model_name_or_path)
        self.model = LayoutXLMForRelationExtraction.from_pretrained(
            args.re_model_name_or_path)
        self.model.eval()

    def __call__(self, img):
        """Predict linked entity pairs for one image.

        Args:
            img: image forwarded to the SER engine.

        Returns:
            list[tuple]: ``(head, tail)`` pairs of SER result records. Each
            tail is used at most once: the first relation that references a
            given ``tail_id`` wins.
        """
        ser_result, ser_inputs = self.ser_engine(img)
        re_input, entity_idx_dict = make_input(ser_inputs, ser_result)

        re_result = self.model(**re_input)
        pred_relations = re_result['pred_relations'][0]

        # Convert the predicted relations back into the OCR/SER records.
        result = []
        # A set gives O(1) membership checks; the ids are already used as
        # dict keys below, so they are hashable.
        used_tail_ids = set()
        for relation in pred_relations:
            tail_id = relation['tail_id']
            if tail_id in used_tail_ids:
                continue
            used_tail_ids.add(tail_id)
            ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
            ocr_info_tail = ser_result[entity_idx_dict[tail_id]]
            result.append((ocr_info_head, ocr_info_tail))

        return result
if __name__ == "__main__":
    # Script entry point: run SER + RE on every input image, append one line
    # per image to infer_results.txt and save a visualisation next to it.
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Collect the images to process.
    infer_imgs = get_image_file_list(args.infer_imgs)
    total = len(infer_imgs)

    # Build the pipeline once, then loop over the images.
    ser_re_engine = SerReSystem(args)
    results_path = os.path.join(args.output_dir, "infer_results.txt")
    with open(results_path, "w", encoding='utf-8') as fout:
        for idx, img_path in enumerate(infer_imgs):
            stem = os.path.splitext(os.path.basename(img_path))[0]
            save_img_path = os.path.join(args.output_dir, stem + "_re.jpg")
            print("process: [{}/{}], save result to {}".format(
                idx, total, save_img_path))

            img = cv2.imread(img_path)
            result = ser_re_engine(img)

            # One tab-separated record per image: path, then JSON payload.
            payload = json.dumps({"result": result, }, ensure_ascii=False)
            fout.write(img_path + "\t" + payload + "\n")

            cv2.imwrite(save_img_path, draw_re_results(img, result))
ppstructure/vqa/metric.py
deleted
100644 → 0
View file @
1ded2ac4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
re
import
numpy
as
np
import
logging
logger = logging.getLogger(__name__)

# Checkpoint directories are named "<PREFIX_CHECKPOINT_DIR>-<step>".
PREFIX_CHECKPOINT_DIR = "checkpoint"
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")


def get_last_checkpoint(folder):
    """Return the path of the highest-numbered checkpoint directory.

    Args:
        folder (str): directory to scan (direct children only).

    Returns:
        str or None: ``folder`` joined with the name of the sub-directory
        matching ``checkpoint-<number>`` that has the largest number, or
        ``None`` when no such sub-directory exists.
    """

    def _step_of(name):
        # Numeric suffix captured by the checkpoint pattern.
        return int(_re_checkpoint.search(name).groups()[0])

    candidates = []
    for entry in os.listdir(folder):
        if _re_checkpoint.search(entry) is None:
            continue
        # Plain files that merely look like checkpoints are ignored.
        if os.path.isdir(os.path.join(folder, entry)):
            candidates.append(entry)

    if not candidates:
        return None
    return os.path.join(folder, max(candidates, key=_step_of))
def re_score(pred_relations, gt_relations, mode="strict"):
    """Evaluate relation-extraction predictions against ground truth.

    Args:
        pred_relations (list): list (one entry per sample) of lists of
            predicted relations.
        gt_relations (list): list (one entry per sample) of lists of ground
            truth relations. Samples are paired with ``pred_relations`` by
            position. Each relation is a dict of the form
                rel = {"head": (start_idx (inclusive), end_idx (exclusive)),
                       "tail": (start_idx (inclusive), end_idx (exclusive)),
                       "head_type": ent_type,
                       "tail_type": ent_type,
                       "type": rel_type}
        mode (str): ``'strict'`` compares head/tail spans and their entity
            types; ``'boundaries'`` compares the spans only.

    Returns:
        dict: ``scores[rel_type]`` holds ``tp``/``fp``/``fn``/``p``/``r``/
        ``f1`` for each relation type, and ``scores["ALL"]`` holds the
        micro-averaged values plus ``Macro_p``/``Macro_r``/``Macro_f1``.

    Raises:
        AssertionError: if ``mode`` is not one of the two supported values.
    """
    assert mode in ["strict", "boundaries"]

    # Relation type 0 is excluded from scoring, which leaves only type 1.
    relation_types = [v for v in [0, 1] if not v == 0]
    scores = {
        rel: {"tp": 0, "fp": 0, "fn": 0}
        for rel in relation_types + ["ALL"]
    }

    if mode == "strict":
        # strict mode takes argument types into account
        def as_key(rel):
            return (rel["head"], rel["head_type"], rel["tail"],
                    rel["tail_type"])
    else:
        # boundaries mode only takes argument spans into account
        def as_key(rel):
            return (rel["head"], rel["tail"])

    # Count TP, FP and FN per type
    for pred_sent, gt_sent in zip(pred_relations, gt_relations):
        for rel_type in relation_types:
            pred_rels = {
                as_key(rel)
                for rel in pred_sent if rel["type"] == rel_type
            }
            gt_rels = {
                as_key(rel)
                for rel in gt_sent if rel["type"] == rel_type
            }

            scores[rel_type]["tp"] += len(pred_rels & gt_rels)
            scores[rel_type]["fp"] += len(pred_rels - gt_rels)
            scores[rel_type]["fn"] += len(gt_rels - pred_rels)

    # Compute per relation type Precision / Recall / F1. The "ALL" entry is
    # also visited here, but its values are overwritten further down.
    for rel_type in scores.keys():
        entry = scores[rel_type]
        if entry["tp"]:
            entry["p"] = entry["tp"] / (entry["fp"] + entry["tp"])
            entry["r"] = entry["tp"] / (entry["fn"] + entry["tp"])
        else:
            entry["p"], entry["r"] = 0, 0

        if not entry["p"] + entry["r"] == 0:
            entry["f1"] = (2 * entry["p"] * entry["r"] /
                           (entry["p"] + entry["r"]))
        else:
            entry["f1"] = 0

    # Compute micro F1 Scores
    tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
    fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
    fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])

    if tp:
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)
    else:
        precision, recall, f1 = 0, 0, 0

    scores["ALL"]["p"] = precision
    scores["ALL"]["r"] = recall
    scores["ALL"]["f1"] = f1
    scores["ALL"]["tp"] = tp
    scores["ALL"]["fp"] = fp
    scores["ALL"]["fn"] = fn

    # Compute Macro F1 Scores
    scores["ALL"]["Macro_f1"] = np.mean(
        [scores[ent_type]["f1"] for ent_type in relation_types])
    scores["ALL"]["Macro_p"] = np.mean(
        [scores[ent_type]["p"] for ent_type in relation_types])
    scores["ALL"]["Macro_r"] = np.mean(
        [scores[ent_type]["r"] for ent_type in relation_types])

    return scores
ppstructure/vqa/requirements.txt
View file @
a323fce6
sentencepiece
yacs
seqeval
\ No newline at end of file
seqeval
paddlenlp>=2.2.1
\ No newline at end of file
ppstructure/vqa/train_re.py
deleted
100644 → 0
View file @
1ded2ac4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
import
random
import
time
import
numpy
as
np
import
paddle
from
paddlenlp.transformers
import
LayoutXLMTokenizer
,
LayoutXLMModel
,
LayoutXLMForRelationExtraction
from
xfun
import
XFUNDataset
from
vqa_utils
import
parse_args
,
get_bio_label_maps
,
print_arguments
,
set_seed
from
data_collator
import
DataCollator
from
eval_re
import
evaluate
from
ppocr.utils.logging
import
get_logger
def train(args):
    """Train a LayoutXLM relation-extraction model on XFUN-style data.

    Builds the tokenizer, model, datasets, data loaders and optimizer from
    ``args``, runs the training loop, optionally evaluates every
    ``args.eval_steps`` steps (rank 0 only) and writes checkpoints to
    ``<args.output_dir>/best_model`` and ``<args.output_dir>/latest_model``.

    Args:
        args: object providing the attributes read below (``output_dir``,
            ``seed``, ``label_map_path``, ``model_name_or_path``, ``resume``,
            ``train_data_dir``/``train_label_path``,
            ``eval_data_dir``/``eval_label_path``, ``max_seq_length``,
            ``per_gpu_train_batch_size``, ``per_gpu_eval_batch_size``,
            ``num_workers``, ``num_train_epochs``, ``learning_rate``,
            ``warmup_steps``, ``adam_epsilon``, ``weight_decay``,
            ``eval_steps``, ``evaluate_during_training``).

    Returns:
        None.
    """
    logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
    rank = paddle.distributed.get_rank()
    distributed = paddle.distributed.get_world_size() > 1

    print_arguments(args, logger)

    # Added here for reproducibility (even between python 2 and 3)
    set_seed(args.seed)

    # id2label_map is not referenced again in this function.
    label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
    # Padding label id is taken from the default CrossEntropyLoss ignore_index.
    pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index

    # dist mode
    if distributed:
        paddle.distributed.init_parallel_env()

    tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
    if not args.resume:
        # Wrap a pretrained backbone in a new relation-extraction model.
        model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
        model = LayoutXLMForRelationExtraction(model, dropout=None)
        logger.info('train from scratch')
    else:
        # Load the complete relation-extraction model from the given path.
        logger.info('resume from {}'.format(args.model_name_or_path))
        model = LayoutXLMForRelationExtraction.from_pretrained(
            args.model_name_or_path)

    # dist mode
    if distributed:
        model = paddle.DataParallel(model)

    train_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.train_data_dir,
        label_path=args.train_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        max_seq_len=args.max_seq_length,
        pad_token_label_id=pad_token_label_id,
        contains_re=True,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    eval_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.eval_data_dir,
        label_path=args.eval_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        max_seq_len=args.max_seq_length,
        pad_token_label_id=pad_token_label_id,
        contains_re=True,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    train_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)

    train_dataloader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=args.num_workers,
        use_shared_memory=True,
        collate_fn=DataCollator())

    eval_dataloader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=args.per_gpu_eval_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        collate_fn=DataCollator())

    # Total number of optimisation steps over all epochs.
    t_total = len(train_dataloader) * args.num_train_epochs

    # build linear decay with warmup lr sch
    # NOTE(review): lr_scheduler is constructed here, but the optimizer below
    # receives args.learning_rate directly and lr_scheduler.step() is
    # commented out in the loop, so this schedule has no effect in this
    # function. Confirm whether that is intentional.
    lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
        learning_rate=args.learning_rate,
        decay_steps=t_total,
        end_lr=0.0,
        power=1.0)
    if args.warmup_steps > 0:
        lr_scheduler = paddle.optimizer.lr.LinearWarmup(
            lr_scheduler,
            args.warmup_steps,
            start_lr=0,
            end_lr=args.learning_rate, )
    grad_clip = paddle.nn.ClipGradByNorm(clip_norm=10)
    optimizer = paddle.optimizer.Adam(
        learning_rate=args.learning_rate,
        parameters=model.parameters(),
        epsilon=args.adam_epsilon,
        grad_clip=grad_clip,
        weight_decay=args.weight_decay)

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = {}".format(len(train_dataset)))
    logger.info(" Num Epochs = {}".format(args.num_train_epochs))
    logger.info(" Instantaneous batch size per GPU = {}".format(
        args.per_gpu_train_batch_size))
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = {}".
        format(args.per_gpu_train_batch_size *
               paddle.distributed.get_world_size()))
    logger.info(" Total optimization steps = {}".format(t_total))

    global_step = 0
    model.clear_gradients()
    train_dataloader_len = len(train_dataloader)
    # Best evaluation result so far; starts at f1 = 0 so the first
    # evaluation always becomes the best model.
    best_metirc = {'f1': 0}
    model.train()

    # Accumulators for the timing / throughput figures in the step log.
    train_reader_cost = 0.0
    train_run_cost = 0.0
    total_samples = 0
    reader_start = time.time()

    # Log every `print_step` iterations.
    print_step = 1

    for epoch in range(int(args.num_train_epochs)):
        for step, batch in enumerate(train_dataloader):
            # Time spent waiting for the data loader since the last step.
            train_reader_cost += time.time() - reader_start
            train_start = time.time()
            outputs = model(**batch)
            train_run_cost += time.time() - train_start
            # The training loss is read from the 'loss' entry of the output.
            loss = outputs['loss']
            loss = loss.mean()

            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            # lr_scheduler.step() # Update learning rate schedule (disabled)

            global_step += 1
            total_samples += batch['image'].shape[0]

            if rank == 0 and step % print_step == 0:
                logger.info(
                    "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
                    format(epoch, args.num_train_epochs, step,
                           train_dataloader_len, global_step,
                           np.mean(loss.numpy()),
                           optimizer.get_lr(), train_reader_cost / print_step,
                           (train_reader_cost + train_run_cost) / print_step,
                           total_samples / print_step, total_samples /
                           (train_reader_cost + train_run_cost)))

                # Restart the accumulators for the next logging window.
                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0

            if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training:
                # Log metrics
                # Only evaluate when single GPU otherwise metrics may not average well
                results = evaluate(model, eval_dataloader, logger)
                # ">=" means a tie with the previous best also overwrites
                # the saved best_model.
                if results['f1'] >= best_metirc['f1']:
                    best_metirc = results
                    output_dir = os.path.join(args.output_dir, "best_model")
                    os.makedirs(output_dir, exist_ok=True)
                    # Under DataParallel the wrapped network is saved via
                    # `_layers`.
                    if distributed:
                        model._layers.save_pretrained(output_dir)
                    else:
                        model.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    paddle.save(args,
                                os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to {}".format(
                        output_dir))
                logger.info("eval results: {}".format(results))
                logger.info("best_metirc: {}".format(best_metirc))
            reader_start = time.time()

        if rank == 0:
            # Save model checkpoint
            output_dir = os.path.join(args.output_dir, "latest_model")
            os.makedirs(output_dir, exist_ok=True)
            if distributed:
                model._layers.save_pretrained(output_dir)
            else:
                model.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)
            paddle.save(args, os.path.join(output_dir, "training_args.bin"))
            logger.info("Saving model checkpoint to {}".format(output_dir))
    logger.info("best_metirc: {}".format(best_metirc))
if __name__ == "__main__":
    # Script entry point. The output directory must exist before train()
    # runs, because train() opens its log file inside it.
    cli_args = parse_args()
    os.makedirs(cli_args.output_dir, exist_ok=True)
    train(cli_args)
ppstructure/vqa/train_ser.py
deleted
100644 → 0
View file @
1ded2ac4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
import
random
import
time
import
copy
import
logging
import
argparse
import
paddle
import
numpy
as
np
from
seqeval.metrics
import
classification_report
,
f1_score
,
precision_score
,
recall_score
from
paddlenlp.transformers
import
LayoutXLMModel
,
LayoutXLMTokenizer
,
LayoutXLMForTokenClassification
from
paddlenlp.transformers
import
LayoutLMModel
,
LayoutLMTokenizer
,
LayoutLMForTokenClassification
from
xfun
import
XFUNDataset
from
vqa_utils
import
parse_args
,
get_bio_label_maps
,
print_arguments
,
set_seed
from
eval_ser
import
evaluate
from
losses
import
SERLoss
from
ppocr.utils.logging
import
get_logger
# Lookup table keyed by model-type name. Each value is a 3-tuple of
# (tokenizer class, base model class, token-classification model class).
# train() indexes it with `args.ser_model_type`.
MODELS = {
    'LayoutXLM':
    (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
    'LayoutLM':
    (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
def _save_checkpoint(model, tokenizer, args, output_dir, distributed, logger):
    """Write model, tokenizer and training args into `output_dir` and log it."""
    os.makedirs(output_dir, exist_ok=True)
    # Under DataParallel the wrapped network is reached through `_layers`.
    target = model._layers if distributed else model
    target.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    paddle.save(args, os.path.join(output_dir, "training_args.bin"))
    logger.info("Saving model checkpoint to {}".format(output_dir))


def train(args):
    """Train a LayoutLM / LayoutXLM token-classification (SER) model.

    Builds the tokenizer, model, datasets, data loaders, LR schedule and
    optimizer from ``args``, runs the training loop, optionally evaluates
    every ``args.eval_steps`` steps (rank 0 only) and writes checkpoints to
    ``<args.output_dir>/best_model`` and ``<args.output_dir>/latest_model``.

    Args:
        args: object providing the attributes read below (``output_dir``,
            ``label_map_path``, ``ser_model_type``, ``model_name_or_path``,
            ``resume``, ``train_data_dir``/``train_label_path``,
            ``eval_data_dir``/``eval_label_path``,
            ``per_gpu_train_batch_size``, ``per_gpu_eval_batch_size``,
            ``num_workers``, ``num_train_epochs``, ``learning_rate``,
            ``warmup_steps``, ``adam_epsilon``, ``weight_decay``, ``seed``,
            ``eval_steps``, ``evaluate_during_training``).

    Returns:
        tuple: ``(global_step, mean_loss)`` where ``mean_loss`` is the summed
        training loss divided by the number of steps, or ``0.0`` when no
        step was executed.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    rank = paddle.distributed.get_rank()
    distributed = paddle.distributed.get_world_size() > 1

    logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
    print_arguments(args, logger)

    label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
    loss_class = SERLoss(len(label2id_map))

    # Padding positions carry the label id that the loss ignores.
    pad_token_label_id = loss_class.ignore_index

    # dist mode
    if distributed:
        paddle.distributed.init_parallel_env()

    tokenizer_class, base_model_class, model_class = MODELS[
        args.ser_model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    if not args.resume:
        # Wrap a pretrained backbone in a new classification model.
        base_model = base_model_class.from_pretrained(args.model_name_or_path)
        model = model_class(
            base_model, num_classes=len(label2id_map), dropout=None)
        logger.info('train from scratch')
    else:
        # Load the complete classification model from the given path.
        logger.info('resume from {}'.format(args.model_name_or_path))
        model = model_class.from_pretrained(args.model_name_or_path)

    # dist mode
    if distributed:
        model = paddle.DataParallel(model)

    train_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.train_data_dir,
        label_path=args.train_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        pad_token_label_id=pad_token_label_id,
        contains_re=False,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')
    eval_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.eval_data_dir,
        label_path=args.eval_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        pad_token_label_id=pad_token_label_id,
        contains_re=False,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    train_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)

    train_dataloader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=args.num_workers,
        use_shared_memory=True,
        collate_fn=None, )

    eval_dataloader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=args.per_gpu_eval_batch_size,
        num_workers=args.num_workers,
        use_shared_memory=True,
        collate_fn=None, )

    steps_per_epoch = len(train_dataloader)
    # Total number of optimisation steps over all epochs.
    t_total = steps_per_epoch * args.num_train_epochs

    # build linear decay with warmup lr sch
    lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
        learning_rate=args.learning_rate,
        decay_steps=t_total,
        end_lr=0.0,
        power=1.0)
    if args.warmup_steps > 0:
        lr_scheduler = paddle.optimizer.lr.LinearWarmup(
            lr_scheduler,
            args.warmup_steps,
            start_lr=0,
            end_lr=args.learning_rate, )

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        epsilon=args.adam_epsilon,
        weight_decay=args.weight_decay)

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed) = %d",
        args.per_gpu_train_batch_size * paddle.distributed.get_world_size(), )
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss = 0.0
    set_seed(args.seed)
    best_metrics = None

    # Accumulators for the timing / throughput figures in the step log.
    train_reader_cost = 0.0
    train_run_cost = 0.0
    total_samples = 0
    reader_start = time.time()

    # Log every `print_step` iterations.
    print_step = 1
    model.train()
    # int() mirrors the relation-extraction trainer and is a no-op for an
    # integer epoch count.
    for epoch_id in range(int(args.num_train_epochs)):
        for step, batch in enumerate(train_dataloader):
            # Time spent waiting for the data loader since the last step.
            train_reader_cost += time.time() - reader_start

            # LayoutLM is called without the image input.
            if args.ser_model_type == 'LayoutLM':
                if 'image' in batch:
                    batch.pop('image')
            # Labels go to the loss, not to the model.
            labels = batch.pop('labels')

            train_start = time.time()
            outputs = model(**batch)
            train_run_cost += time.time() - train_start
            # Only the first element of the LayoutXLM output is used.
            if args.ser_model_type == 'LayoutXLM':
                outputs = outputs[0]
            loss = loss_class(labels, outputs, batch['attention_mask'])

            loss = loss.mean()
            loss.backward()
            # Read the scalar once; it feeds both the running sum and the
            # log line, independent of the tensor's shape.
            loss_value = loss.item()
            tr_loss += loss_value
            optimizer.step()
            lr_scheduler.step()  # Update learning rate schedule
            optimizer.clear_grad()
            global_step += 1
            total_samples += batch['input_ids'].shape[0]

            if rank == 0 and step % print_step == 0:
                logger.info(
                    "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
                    format(epoch_id, args.num_train_epochs, step,
                           steps_per_epoch, global_step, loss_value,
                           lr_scheduler.get_lr(),
                           train_reader_cost / print_step,
                           (train_reader_cost + train_run_cost) / print_step,
                           total_samples / print_step, total_samples /
                           (train_reader_cost + train_run_cost)))

                # Restart the accumulators for the next logging window.
                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0

            if (rank == 0 and args.eval_steps > 0 and
                    global_step % args.eval_steps == 0 and
                    args.evaluate_during_training):
                # Log metrics
                # Only evaluate when single GPU otherwise metrics may not average well
                results, _ = evaluate(args, model, tokenizer, loss_class,
                                      eval_dataloader, label2id_map,
                                      id2label_map, pad_token_label_id, logger)

                # ">=" means a tie with the previous best also overwrites
                # the saved best_model.
                if best_metrics is None or results["f1"] >= best_metrics[
                        "f1"]:
                    best_metrics = copy.deepcopy(results)
                    _save_checkpoint(
                        model, tokenizer, args,
                        os.path.join(args.output_dir, "best_model"),
                        distributed, logger)

                logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format(
                    epoch_id, args.num_train_epochs, step, steps_per_epoch,
                    results))
                if best_metrics is not None:
                    logger.info("best metrics: {}".format(best_metrics))
            reader_start = time.time()
        if rank == 0:
            # Save model checkpoint
            _save_checkpoint(model, tokenizer, args,
                             os.path.join(args.output_dir, "latest_model"),
                             distributed, logger)

    # Guard against ZeroDivisionError when no training step ran (empty data
    # loader or zero epochs).
    mean_loss = tr_loss / global_step if global_step else 0.0
    return global_step, mean_loss
# Script entry point: only runs when the file is executed directly, not when
# it is imported. Parses the command-line options and hands them to train().
if __name__ == "__main__":
    args = parse_args()
    train(args)
ppstructure/vqa/vqa_utils.py
deleted
100644 → 0
View file @
1ded2ac4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
argparse
import
cv2
import
random
import
numpy
as
np
import
imghdr
from
copy
import
deepcopy
import
paddle
from
PIL
import
Image
,
ImageDraw
,
ImageFont
def set_seed(seed):
    """Seed Python's ``random``, NumPy and Paddle with the same value.

    Args:
        seed: value forwarded unchanged to each library's seeding function.
    """
    for seed_fn in (random.seed, np.random.seed, paddle.seed):
        seed_fn(seed)
def get_bio_label_maps(label_map_path):
    """Build BIO label <-> id lookup tables from a label list file.

    The file is read as UTF-8 with one class name per line. ``"O"`` is put at
    the front when the file does not list it. Every other class ``X`` expands
    to the two labels ``"B-X"`` and ``"I-X"``. Ids are assigned in that order,
    starting at 0.

    Blank (or whitespace-only) lines are ignored; otherwise they would turn
    into bogus ``"B-"`` / ``"I-"`` labels and shift every following id.

    Args:
        label_map_path: path of the text file holding the class names.

    Returns:
        tuple: ``(label2id_map, id2label_map)`` -- label string to id and the
        inverse mapping.
    """
    with open(label_map_path, "r", encoding='utf-8') as fin:
        class_names = [line.strip() for line in fin]
    # Drop empty entries left by blank lines in the file.
    class_names = [name for name in class_names if name]

    if "O" not in class_names:
        class_names.insert(0, "O")

    labels = []
    for name in class_names:
        if name == "O":
            labels.append("O")
        else:
            labels.append("B-" + name)
            labels.append("I-" + name)

    label2id_map = {label: idx for idx, label in enumerate(labels)}
    id2label_map = {idx: label for idx, label in enumerate(labels)}
    return label2id_map, id2label_map
def get_image_file_list(img_file):
    """Collect the image files designated by ``img_file``.

    ``img_file`` may be a single file or a directory. A file is kept when
    ``imghdr.what`` reports one of the accepted formats. A directory is
    scanned one level deep (no recursion).

    Args:
        img_file: path of an image file or of a directory of images.

    Returns:
        list: the accepted paths, sorted.

    Raises:
        Exception: if the path is ``None``, does not exist, or yields no
            accepted image.
    """
    not_found_msg = "not found any img file in {}".format(img_file)
    if img_file is None or not os.path.exists(img_file):
        raise Exception(not_found_msg)

    img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}

    def is_image(path):
        # Regular file whose sniffed format is in the accepted set.
        return os.path.isfile(path) and imghdr.what(path) in img_end

    if is_image(img_file):
        found = [img_file]
    elif os.path.isdir(img_file):
        candidates = (os.path.join(img_file, entry)
                      for entry in os.listdir(img_file))
        found = [path for path in candidates if is_image(path)]
    else:
        found = []

    if not found:
        raise Exception(not_found_msg)
    return sorted(found)
def draw_ser_results(image,
                     ocr_results,
                     font_path="../../doc/fonts/simfang.ttf",
                     font_size=18):
    """Overlay SER predictions on an image.

    Each OCR entry whose ``"pred_id"`` has a colour (ids 1..254) gets a filled
    box with a ``"<pred>: <text>"`` caption. The annotated copy is then
    blended 50/50 with the input image.

    Args:
        image: a ``numpy.ndarray`` or a PIL image.
        ocr_results: iterable of dicts with the keys ``"pred_id"``,
            ``"pred"``, ``"text"`` and ``"bbox"``.
        font_path: TrueType font used for the captions.
        font_size: caption font size.

    Returns:
        numpy.ndarray: the blended image.
    """
    # Fixed seed so every class id keeps the same colour between calls.
    # NOTE: this reseeds NumPy's global random generator.
    np.random.seed(2021)
    reds, greens, blues = (np.random.permutation(range(255)) for _ in range(3))
    color_map = {
        idx: (reds[idx], greens[idx], blues[idx])
        for idx in range(1, 255)
    }

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    canvas = image.copy()
    painter = ImageDraw.Draw(canvas)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

    for ocr_info in ocr_results:
        box_color = color_map.get(ocr_info["pred_id"])
        if box_color is None:
            # No colour for this id (e.g. id 0), so nothing is drawn.
            continue
        caption = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
        draw_box_txt(ocr_info["bbox"], caption, painter, font, font_size,
                     box_color)

    blended = Image.blend(image, canvas, 0.5)
    return np.array(blended)
def draw_box_txt(bbox, text, draw, font, font_size, color):
    """Draw a filled box and a caption strip just above it.

    Args:
        bbox: ``[x1, y1, x2, y2]`` box coordinates.
        text: caption to render.
        draw: drawing object providing ``rectangle`` and ``text``
            (a ``PIL.ImageDraw.ImageDraw`` in this file).
        font: font object used to measure and render ``text``.
        font_size: height of the caption strip.
        color: fill colour of the box.
    """
    # draw ocr results outline
    bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
    draw.rectangle(bbox, fill=color)

    # draw ocr results
    # The strip sits directly above the box, clamped to the top image edge.
    start_y = max(0, bbox[0][1] - font_size)
    if hasattr(font, "getbbox"):
        # ``getsize`` was removed in Pillow 10; ``right - left`` of the text
        # bounding box is the documented replacement for its width.
        left, _, right, _ = font.getbbox(text)
        tw = right - left
    else:
        # Older Pillow releases that do not provide ``getbbox``.
        tw = font.getsize(text)[0]
    draw.rectangle(
        [(bbox[0][0] + 1, start_y),
         (bbox[0][0] + tw + 1, start_y + font_size)],
        fill=(0, 0, 255))
    draw.text(
        (bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
def draw_re_results(image,
                    result,
                    font_path="../../doc/fonts/simfang.ttf",
                    font_size=18):
    """Overlay relation-extraction pairs on an image.

    For every ``(head, tail)`` pair the head box is drawn in blue, the tail
    box in red, and a green line joins the two box centres. The annotated
    copy is then blended 50/50 with the input image.

    Args:
        image: a ``numpy.ndarray`` or a PIL image.
        result: iterable of ``(head, tail)`` dicts, each with the keys
            ``"bbox"`` and ``"text"``.
        font_path: TrueType font used for the captions.
        font_size: caption font size.

    Returns:
        numpy.ndarray: the blended image.
    """
    # NOTE: this reseeds NumPy's global random generator.
    np.random.seed(0)
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    canvas = image.copy()
    painter = ImageDraw.Draw(canvas)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

    color_head = (0, 0, 255)
    color_tail = (255, 0, 0)
    color_line = (0, 255, 0)

    def box_center(box):
        # Integer midpoint of an [x1, y1, x2, y2] box.
        return (box[0] + box[2]) // 2, (box[1] + box[3]) // 2

    for ocr_info_head, ocr_info_tail in result:
        for entry, fill in ((ocr_info_head, color_head),
                            (ocr_info_tail, color_tail)):
            draw_box_txt(entry["bbox"], entry["text"], painter, font,
                         font_size, fill)

        endpoints = [
            box_center(ocr_info_head['bbox']),
            box_center(ocr_info_tail['bbox']),
        ]
        painter.line(endpoints, fill=color_line, width=5)

    blended = Image.blend(image, canvas, 0.5)
    return np.array(blended)
# pad sentences
def pad_sentences(tokenizer,
                  encoded_inputs,
                  max_seq_len=512,
                  pad_to_max_seq_len=True,
                  return_attention_mask=True,
                  return_token_type_ids=True,
                  return_overflowing_tokens=False,
                  return_special_tokens_mask=False):
    """Pad a page to the smallest multiple of ``max_seq_len`` that fits it.

    The padded length is ``ceil(len(input_ids) / max_seq_len) * max_seq_len``
    and never less than one chunk, so an empty page is still padded to one
    chunk. An input whose length is already an exact multiple is left
    unpadded. It used to gain a whole extra all-padding chunk.

    Only right-side padding is implemented. With any other
    ``tokenizer.padding_side`` the inputs are returned unpadded.

    Args:
        tokenizer: object providing ``padding_side``, ``pad_token_id`` and
            ``pad_token_type_id``.
        encoded_inputs: dict with the list-valued keys ``"input_ids"`` and
            ``"bbox"``, plus ``"token_type_ids"`` / ``"special_tokens_mask"``
            when the matching flags are set. It is updated in place.
        max_seq_len: chunk length the result must be a multiple of.
        pad_to_max_seq_len: when false, no padding is added.
        return_attention_mask: write ``"attention_mask"`` (1 = real token,
            0 = padding).
        return_token_type_ids: pad ``"token_type_ids"``.
        return_overflowing_tokens: accepted but not used here.
        return_special_tokens_mask: pad ``"special_tokens_mask"`` with 1s.

    Returns:
        dict: the same ``encoded_inputs`` object.
    """
    seq_len = len(encoded_inputs["input_ids"])
    # Padding with larger size, reshape is carried out.
    # Ceiling division, so an exact multiple needs no additional chunk.
    num_chunks = max(1, -(-seq_len // max_seq_len))
    max_seq_len = num_chunks * max_seq_len

    needs_to_be_padded = pad_to_max_seq_len and seq_len < max_seq_len

    if needs_to_be_padded:
        difference = max_seq_len - seq_len
        if tokenizer.padding_side == 'right':
            if return_attention_mask:
                encoded_inputs["attention_mask"] = (
                    [1] * seq_len + [0] * difference)
            if return_token_type_ids:
                encoded_inputs["token_type_ids"] = (
                    encoded_inputs["token_type_ids"] +
                    [tokenizer.pad_token_type_id] * difference)
            if return_special_tokens_mask:
                encoded_inputs["special_tokens_mask"] = (
                    encoded_inputs["special_tokens_mask"] + [1] * difference)
            encoded_inputs["input_ids"] = (
                encoded_inputs["input_ids"] +
                [tokenizer.pad_token_id] * difference)
            encoded_inputs["bbox"] = (
                encoded_inputs["bbox"] + [[0, 0, 0, 0]] * difference)
    else:
        if return_attention_mask:
            encoded_inputs["attention_mask"] = [1] * seq_len

    return encoded_inputs
def split_page(encoded_inputs, max_seq_len=512):
    """Turn a padded page into tensors split into rows of ``max_seq_len``.

    ``"entities"`` is wrapped in a one-element list and left otherwise
    untouched. Every other value is converted with ``paddle.to_tensor``:
    tensors with at most one dimension are reshaped to ``[-1, max_seq_len]``
    and the rest (the boxes) to ``[-1, max_seq_len, 4]``.

    Args:
        encoded_inputs: dict of per-token sequences; updated in place.
        max_seq_len: row length after the reshape.

    Returns:
        dict: the same ``encoded_inputs`` object.
    """
    for key, value in encoded_inputs.items():
        if key == 'entities':
            encoded_inputs[key] = [value]
            continue

        tensor = paddle.to_tensor(value)
        if tensor.ndim <= 1:
            # for input_ids, att_mask and so on
            target_shape = [-1, max_seq_len]
        else:
            # for bbox
            target_shape = [-1, max_seq_len, 4]
        encoded_inputs[key] = tensor.reshape(target_shape)
    return encoded_inputs
def preprocess(
        tokenizer,
        ori_img,
        ocr_info,
        img_size=(224, 224),
        pad_token_label_id=-100,
        max_seq_len=512,
        add_special_ids=False,
        return_attention_mask=True, ):
    """Encode one image and its OCR results into model inputs.

    Each OCR entry is tokenised and its box is rescaled to a 0-1000 grid
    relative to the original image size. The concatenated tokens are padded
    with ``pad_sentences`` and split with ``split_page`` into rows of
    ``max_seq_len``.

    ``max_seq_len`` is now forwarded to ``split_page`` as well. Previously
    the split always used 512, which broke any other value.

    Args:
        tokenizer: tokenizer whose ``encode`` returns a mapping with
            ``"input_ids"``, ``"token_type_ids"`` and ``"attention_mask"``.
        ori_img: image array; ``shape[0]`` / ``shape[1]`` are read as height
            and width. Presumably H x W x C, since it is transposed with
            ``[2, 0, 1]`` after resizing -- TODO confirm the channel order
            expected by the model.
        ocr_info: list of dicts with the keys ``"bbox"`` (x1, y1, x2, y2) and
            ``"text"``. A deep copy is made, so the caller's data is not
            modified.
        img_size: size passed to ``cv2.resize``.
        pad_token_label_id: accepted but not used by this function.
        max_seq_len: chunk length used for padding and splitting.
        add_special_ids: when false, the first and last token of every
            encoded text are dropped.
        return_attention_mask: forwarded to ``pad_sentences``.

    Returns:
        dict: ``"input_ids"``, ``"token_type_ids"``, ``"bbox"`` and
        ``"attention_mask"`` tensors, ``"entities"`` (one entry per OCR item,
        wrapped in a list by ``split_page``), ``"image"`` (the resized image
        repeated once per chunk) and ``"segment_offset_id"`` (cumulative
        token end offset of every OCR item).
    """
    ocr_info = deepcopy(ocr_info)
    height = ori_img.shape[0]
    width = ori_img.shape[1]

    img = cv2.resize(ori_img,
                     img_size).transpose([2, 0, 1]).astype(np.float32)

    segment_offset_id = []
    bbox_list = []
    input_ids_list = []
    token_type_ids_list = []
    entities = []

    for info in ocr_info:
        # x1, y1, x2, y2
        bbox = info["bbox"]
        bbox[0] = int(bbox[0] * 1000.0 / width)
        bbox[2] = int(bbox[2] * 1000.0 / width)
        bbox[1] = int(bbox[1] * 1000.0 / height)
        bbox[3] = int(bbox[3] * 1000.0 / height)

        text = info["text"]
        encode_res = tokenizer.encode(
            text, pad_to_max_seq_len=False, return_attention_mask=True)

        if not add_special_ids:
            # TODO: use tok.all_special_ids to remove
            encode_res["input_ids"] = encode_res["input_ids"][1:-1]
            encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
            encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]

        # for re
        entities.append({
            "start": len(input_ids_list),
            "end": len(input_ids_list) + len(encode_res["input_ids"]),
            "label": "O",
        })

        input_ids_list.extend(encode_res["input_ids"])
        token_type_ids_list.extend(encode_res["token_type_ids"])
        # Every token of a text line shares that line's box.
        bbox_list.extend([bbox] * len(encode_res["input_ids"]))
        segment_offset_id.append(len(input_ids_list))

    encoded_inputs = {
        "input_ids": input_ids_list,
        "token_type_ids": token_type_ids_list,
        "bbox": bbox_list,
        "attention_mask": [1] * len(input_ids_list),
        "entities": entities
    }

    encoded_inputs = pad_sentences(
        tokenizer,
        encoded_inputs,
        max_seq_len=max_seq_len,
        return_attention_mask=return_attention_mask)

    # Split with the same chunk length that was used for padding.
    encoded_inputs = split_page(encoded_inputs, max_seq_len=max_seq_len)

    # One row per chunk: the image is repeated for each of them.
    fake_bs = encoded_inputs["input_ids"].shape[0]

    encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
        [fake_bs] + list(img.shape))

    encoded_inputs["segment_offset_id"] = segment_offset_id

    return encoded_inputs
def postprocess(attention_mask, preds, id2label_map):
    """Convert per-token scores into label names, dropping masked positions.

    Args:
        attention_mask: 2-D indexable; a position is kept when its value
            equals 1.
        preds: scores as a ``paddle.Tensor`` or array; the arg-max is taken
            over axis 2.
        id2label_map: mapping from predicted id to label name.

    Returns:
        list: one list of label names per row of ``preds``.
    """
    if isinstance(preds, paddle.Tensor):
        preds = preds.numpy()
    label_ids = np.argmax(preds, axis=2)
    num_rows, num_cols = label_ids.shape[0], label_ids.shape[1]

    # keep batch info
    return [[
        id2label_map[label_ids[row][col]] for col in range(num_cols)
        if attention_mask[row][col] == 1
    ] for row in range(num_rows)]
def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list,
                                   label2id_map_for_draw):
    """Attach one predicted class to every OCR entry by majority vote.

    The token labels of all chunks are flattened. Entry ``i`` owns the tokens
    from ``segment_offset_id[i - 1]`` (0 for the first entry) up to
    ``segment_offset_id[i]``. Its class is the most frequent id among those
    tokens, or 0 when it owns none.

    Args:
        ocr_info: list of dicts; ``"pred_id"`` and ``"pred"`` are written
            into each entry in place.
        segment_offset_id: cumulative token end offset of every entry.
        preds_list: list of per-chunk lists of label names.
        label2id_map_for_draw: label name to id. The reverse lookup strips a
            leading ``"B-"`` / ``"I-"`` from the names.

    Returns:
        list: the same ``ocr_info`` object.
    """
    # must ensure the preds_list is generated from the same image
    flat_preds = [label for chunk in preds_list for label in chunk]

    id2label_map = {
        label_id: name[2:] if name.startswith(("B-", "I-")) else name
        for name, label_id in label2id_map_for_draw.items()
    }

    seg_start = 0
    for seg_idx, seg_end in enumerate(segment_offset_id):
        seg_ids = [
            label2id_map_for_draw[label]
            for label in flat_preds[seg_start:seg_end]
        ]
        if seg_ids:
            winner = int(np.argmax(np.bincount(seg_ids)))
        else:
            winner = 0
        ocr_info[seg_idx]["pred_id"] = winner
        ocr_info[seg_idx]["pred"] = id2label_map[winner]
        seg_start = seg_end
    return ocr_info
def print_arguments(args, logger=None):
    """Print every attribute of ``args`` as ``name: value``, sorted by name.

    Args:
        args: object whose ``vars()`` are listed (e.g. an argparse namespace).
        logger: optional logger; its ``info`` method is used when given,
            otherwise the lines go to ``print``.
    """
    if logger is None:
        emit = print
    else:
        emit = logger.info

    settings = vars(args)
    emit('----------- Configuration Arguments -----------')
    for name in sorted(settings):
        emit('%s: %s' % (name, settings[name]))
    emit('------------------------------------------------')
def parse_args():
    """Declare the command-line options and parse them.

    ``--model_name_or_path`` and ``--output_dir`` are mandatory; every other
    option has a default.

    Returns:
        argparse.Namespace: the parsed options.
    """
    options = [
        # Required parameters
        ("--model_name_or_path", dict(default=None, type=str, required=True)),
        ("--ser_model_type", dict(default='LayoutXLM', type=str)),
        ("--re_model_name_or_path",
         dict(default=None, type=str, required=False)),
        ("--train_data_dir", dict(default=None, type=str, required=False)),
        ("--train_label_path", dict(default=None, type=str, required=False)),
        ("--eval_data_dir", dict(default=None, type=str, required=False)),
        ("--eval_label_path", dict(default=None, type=str, required=False)),
        ("--output_dir", dict(default=None, type=str, required=True)),
        ("--max_seq_length", dict(default=512, type=int)),
        ("--evaluate_during_training", dict(action="store_true")),
        ("--num_workers", dict(default=8, type=int)),
        ("--per_gpu_train_batch_size",
         dict(default=8, type=int,
              help="Batch size per GPU/CPU for training.")),
        ("--per_gpu_eval_batch_size",
         dict(default=8, type=int, help="Batch size per GPU/CPU for eval.")),
        ("--learning_rate",
         dict(default=5e-5, type=float,
              help="The initial learning rate for Adam.")),
        ("--weight_decay",
         dict(default=0.0, type=float,
              help="Weight decay if we apply some.")),
        ("--adam_epsilon",
         dict(default=1e-8, type=float, help="Epsilon for Adam optimizer.")),
        ("--max_grad_norm",
         dict(default=1.0, type=float, help="Max gradient norm.")),
        ("--num_train_epochs",
         dict(default=3, type=int,
              help="Total number of training epochs to perform.")),
        ("--warmup_steps",
         dict(default=0, type=int,
              help="Linear warmup over warmup_steps.")),
        ("--eval_steps",
         dict(type=int, default=10, help="eval every X updates steps.")),
        ("--seed",
         dict(type=int, default=2048,
              help="random seed for initialization")),
        ("--rec_model_dir", dict(default=None, type=str)),
        ("--det_model_dir", dict(default=None, type=str)),
        ("--label_map_path",
         dict(default="./labels/labels_ser.txt", type=str, required=False)),
        ("--infer_imgs", dict(default=None, type=str, required=False)),
        ("--resume", dict(action='store_true')),
        ("--ocr_json_path",
         dict(default=None, type=str, required=False,
              help="ocr prediction results")),
    ]

    parser = argparse.ArgumentParser()
    for flag, settings in options:
        parser.add_argument(flag, **settings)
    return parser.parse_args()
ppstructure/vqa/xfun.py
deleted
100644 → 0
View file @
1ded2ac4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
os
import
cv2
import
numpy
as
np
import
paddle
import
copy
from
paddle.io
import
Dataset
__all__
=
[
"XFUNDataset"
]
class XFUNDataset(Dataset):
    """Dataset that encodes label files for SER / RE training.

    Every line of ``label_path`` holds an image name and a JSON document,
    separated by a tab. The JSON provides ``height``, ``width`` and an
    ``ocr_info`` list whose entries carry ``bbox``, ``label`` and ``text``
    (plus ``id`` and ``linking`` when ``contains_re`` is true).

    Samples are padded, truncated and chunked to ``max_seq_len`` tokens.

    Example:
        from paddlenlp.transformers import LayoutXLMTokenizer

        tokenizer = LayoutXLMTokenizer.from_pretrained(
            "/paddle/models/transformers/layoutxlm-base-paddle/")
        dataset = XFUNDataset(
            tokenizer,
            data_dir="./zh.val/",
            label_path="zh.val/xfun_normalize_val.json",
            label2id_map=label2id_map,
            img_size=(224, 224))
        print(len(dataset))
        data = dataset[0]
        print(data.keys())
        print("input_ids: ", data["input_ids"])
        print("labels: ", data["labels"])
        print("token_type_ids: ", data["token_type_ids"])
        print("image shape: ", data["image"].shape)
    """

    def __init__(self,
                 tokenizer,
                 data_dir,
                 label_path,
                 contains_re=False,
                 label2id_map=None,
                 img_size=(224, 224),
                 pad_token_label_id=None,
                 add_special_ids=False,
                 return_attention_mask=True,
                 load_mode='all',
                 max_seq_len=512):
        """Store the configuration and, in ``'all'`` mode, encode every line.

        Args:
            tokenizer: tokenizer providing ``encode``, ``padding_side``,
                ``pad_token_id`` and ``pad_token_type_id``.
            data_dir: directory the image names are joined onto.
            label_path: path of the label file.
            contains_re: also build entities and relations (RE task).
            label2id_map: BIO label name to id, e.g. ``"B-QUESTION"``.
            img_size: ``(height, width)`` the images are resized to.
            pad_token_label_id: label used for padding; defaults to the
                ``ignore_index`` of ``paddle.nn.CrossEntropyLoss``.
            add_special_ids: keep the first and last token of each encoded
                text instead of dropping them.
            return_attention_mask: forwarded to ``pad_sentences``.
            load_mode: ``'all'`` encodes the whole file up front; any other
                value encodes a line on every ``__getitem__`` call.
            max_seq_len: sequence length of every returned sample.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.data_dir = data_dir
        self.label_path = label_path
        self.contains_re = contains_re
        self.label2id_map = label2id_map
        self.img_size = img_size
        self.pad_token_label_id = pad_token_label_id
        self.add_special_ids = add_special_ids
        self.return_attention_mask = return_attention_mask
        self.load_mode = load_mode
        self.max_seq_len = max_seq_len

        if self.pad_token_label_id is None:
            self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index

        self.all_lines = self.read_all_lines()

        self.entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
        # Keys handed out by __getitem__ and how each value is converted.
        self.return_keys = {
            'bbox': {
                'type': 'np',
                'dtype': 'int64'
            },
            'input_ids': {
                'type': 'np',
                'dtype': 'int64'
            },
            'labels': {
                'type': 'np',
                'dtype': 'int64'
            },
            'attention_mask': {
                'type': 'np',
                'dtype': 'int64'
            },
            'image': {
                'type': 'np',
                'dtype': 'float32'
            },
            'token_type_ids': {
                'type': 'np',
                'dtype': 'int64'
            },
            'entities': {
                'type': 'dict'
            },
            'relations': {
                'type': 'dict'
            }
        }

        if load_mode == "all":
            self.encoded_inputs_all = self._parse_label_file_all()

    def pad_sentences(self,
                      encoded_inputs,
                      max_seq_len=512,
                      pad_to_max_seq_len=True,
                      return_attention_mask=True,
                      return_token_type_ids=True,
                      truncation_strategy="longest_first",
                      return_overflowing_tokens=False,
                      return_special_tokens_mask=False):
        """Pad a sample to ``max_seq_len`` on the tokenizer's padding side.

        Inputs that are already at least ``max_seq_len`` long are not padded
        and only receive an all-ones attention mask (when requested).
        ``truncation_strategy`` and ``return_overflowing_tokens`` are
        accepted but not used here.

        Returns:
            dict: the same ``encoded_inputs`` object, updated in place.
        """
        # Padding
        needs_to_be_padded = pad_to_max_seq_len and \
                             max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len

        if needs_to_be_padded:
            difference = max_seq_len - len(encoded_inputs["input_ids"])
            if self.tokenizer.padding_side == 'right':
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = [1] * len(
                        encoded_inputs["input_ids"]) + [0] * difference
                if return_token_type_ids:
                    encoded_inputs["token_type_ids"] = (
                        encoded_inputs["token_type_ids"] +
                        [self.tokenizer.pad_token_type_id] * difference)
                if return_special_tokens_mask:
                    encoded_inputs["special_tokens_mask"] = encoded_inputs[
                        "special_tokens_mask"] + [1] * difference
                encoded_inputs["input_ids"] = encoded_inputs[
                    "input_ids"] + [self.tokenizer.pad_token_id] * difference
                encoded_inputs["labels"] = encoded_inputs[
                    "labels"] + [self.pad_token_label_id] * difference
                encoded_inputs["bbox"] = encoded_inputs[
                    "bbox"] + [[0, 0, 0, 0]] * difference
            elif self.tokenizer.padding_side == 'left':
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = [0] * difference + [
                        1
                    ] * len(encoded_inputs["input_ids"])
                if return_token_type_ids:
                    encoded_inputs["token_type_ids"] = (
                        [self.tokenizer.pad_token_type_id] * difference +
                        encoded_inputs["token_type_ids"])
                if return_special_tokens_mask:
                    encoded_inputs["special_tokens_mask"] = [
                        1
                    ] * difference + encoded_inputs["special_tokens_mask"]
                encoded_inputs["input_ids"] = [
                    self.tokenizer.pad_token_id
                ] * difference + encoded_inputs["input_ids"]
                encoded_inputs["labels"] = [
                    self.pad_token_label_id
                ] * difference + encoded_inputs["labels"]
                encoded_inputs["bbox"] = [
                    [0, 0, 0, 0]
                ] * difference + encoded_inputs["bbox"]
        else:
            if return_attention_mask:
                encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
                    "input_ids"])

        return encoded_inputs

    def truncate_inputs(self, encoded_inputs, max_seq_len=512):
        """Cut every sequence (except ``"sample_id"``) to ``max_seq_len``.

        Returns:
            dict: the same ``encoded_inputs`` object, updated in place.
        """
        for key in encoded_inputs:
            if key == "sample_id":
                continue
            length = min(len(encoded_inputs[key]), max_seq_len)
            encoded_inputs[key] = encoded_inputs[key][:length]
        return encoded_inputs

    def read_all_lines(self, ):
        """Return the lines of the label file, read as UTF-8."""
        with open(self.label_path, "r", encoding='utf-8') as fin:
            lines = fin.readlines()
        return lines

    def _parse_label_file_all(self):
        """
        parse all samples
        """
        encoded_inputs_all = []
        for line in self.all_lines:
            encoded_inputs_all.extend(self._parse_label_file(line))
        return encoded_inputs_all

    def _parse_label_file(self, line):
        """
        parse single sample

        Returns the list of chunks built from one label line; every chunk
        also carries the ``image_path`` of its source image.
        """

        image_name, info_str = line.split("\t")
        image_path = os.path.join(self.data_dir, image_name)

        def add_image_path(x):
            x['image_path'] = image_path
            return x

        encoded_inputs = self._read_encoded_inputs_sample(info_str)
        if self.contains_re:
            encoded_inputs = self._chunk_re(encoded_inputs)
        else:
            encoded_inputs = self._chunk_ser(encoded_inputs)
        encoded_inputs = list(map(add_image_path, encoded_inputs))
        return encoded_inputs

    def _read_encoded_inputs_sample(self, info_str):
        """
        parse label info

        Tokenises every OCR entry of one page and concatenates tokens, boxes
        (rescaled to a 0-1000 grid) and BIO label ids. The result is padded
        and truncated to ``self.max_seq_len``. In RE mode the entities and
        relations of the page are added as well.
        """
        # read text info
        info_dict = json.loads(info_str)
        height = info_dict["height"]
        width = info_dict["width"]

        bbox_list = []
        input_ids_list = []
        token_type_ids_list = []
        gt_label_list = []

        if self.contains_re:
            # for re
            entities = []
            relations = []
            id2label = {}
            entity_id_to_index_map = {}
            empty_entity = set()
        for info in info_dict["ocr_info"]:
            if self.contains_re:
                # for re
                if len(info["text"]) == 0:
                    empty_entity.add(info["id"])
                    continue
                id2label[info["id"]] = info["label"]
                relations.extend([tuple(sorted(l)) for l in info["linking"]])

            # x1, y1, x2, y2
            bbox = info["bbox"]
            label = info["label"]
            bbox[0] = int(bbox[0] * 1000.0 / width)
            bbox[2] = int(bbox[2] * 1000.0 / width)
            bbox[1] = int(bbox[1] * 1000.0 / height)
            bbox[3] = int(bbox[3] * 1000.0 / height)

            text = info["text"]
            encode_res = self.tokenizer.encode(
                text, pad_to_max_seq_len=False, return_attention_mask=True)

            gt_label = []
            if not self.add_special_ids:
                # TODO: use tok.all_special_ids to remove
                encode_res["input_ids"] = encode_res["input_ids"][1:-1]
                encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
                                                                            -1]
                encode_res["attention_mask"] = encode_res["attention_mask"][1:
                                                                            -1]

            num_tokens = len(encode_res["input_ids"])
            if num_tokens == 0:
                # No token to label: emitting the usual "B-" label here would
                # make `labels` one element longer than `input_ids` and shift
                # every later label. Skip the entry, and in RE mode treat it
                # like an empty entity so relations pointing at it are
                # filtered out.
                if self.contains_re:
                    empty_entity.add(info["id"])
                continue

            if label.lower() == "other":
                gt_label.extend([0] * num_tokens)
            else:
                gt_label.append(self.label2id_map[("b-" + label).upper()])
                gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
                                (num_tokens - 1))
            if self.contains_re:
                if gt_label[0] != self.label2id_map["O"]:
                    entity_id_to_index_map[info["id"]] = len(entities)
                    entities.append({
                        "start": len(input_ids_list),
                        "end": len(input_ids_list) + num_tokens,
                        "label": label.upper(),
                    })
            input_ids_list.extend(encode_res["input_ids"])
            token_type_ids_list.extend(encode_res["token_type_ids"])
            # Every token of a text line shares that line's box.
            bbox_list.extend([bbox] * num_tokens)
            gt_label_list.extend(gt_label)

        encoded_inputs = {
            "input_ids": input_ids_list,
            "labels": gt_label_list,
            "token_type_ids": token_type_ids_list,
            "bbox": bbox_list,
            "attention_mask": [1] * len(input_ids_list),
        }
        encoded_inputs = self.pad_sentences(
            encoded_inputs,
            max_seq_len=self.max_seq_len,
            return_attention_mask=self.return_attention_mask)
        # Truncate to the configured length, not the method's default of 512,
        # so padding, truncation and chunking all agree.
        encoded_inputs = self.truncate_inputs(
            encoded_inputs, max_seq_len=self.max_seq_len)

        if self.contains_re:
            relations = self._relations(entities, relations, id2label,
                                        empty_entity, entity_id_to_index_map)
            encoded_inputs['relations'] = relations
            encoded_inputs['entities'] = entities
        return encoded_inputs

    def _chunk_ser(self, encoded_inputs):
        """Split one encoded page into chunks of ``self.max_seq_len``."""
        encoded_inputs_all = []
        seq_len = len(encoded_inputs['input_ids'])
        chunk_size = self.max_seq_len
        for index in range(0, seq_len, chunk_size):
            chunk_beg = index
            chunk_end = min(index + chunk_size, seq_len)
            encoded_inputs_example = {}
            for key in encoded_inputs:
                encoded_inputs_example[key] = encoded_inputs[key][chunk_beg:
                                                                  chunk_end]

            encoded_inputs_all.append(encoded_inputs_example)
        return encoded_inputs_all

    def _chunk_re(self, encoded_inputs):
        """Split one encoded page into chunks, keeping per-chunk RE targets.

        Entities and relations that lie completely inside a chunk are kept
        and re-indexed relative to that chunk. Entity labels are converted to
        ids through ``self.entities_labels``.
        """
        # prepare data
        entities = encoded_inputs.pop('entities')
        relations = encoded_inputs.pop('relations')
        encoded_inputs_all = []
        chunk_size = self.max_seq_len
        for index in range(0, len(encoded_inputs["input_ids"]), chunk_size):
            item = {}
            for k in encoded_inputs:
                item[k] = encoded_inputs[k][index:index + chunk_size]

            # select entity in current chunk
            entities_in_this_span = []
            global_to_local_map = {}
            for entity_id, entity in enumerate(entities):
                if (index <= entity["start"] < index + chunk_size and
                        index <= entity["end"] < index + chunk_size):
                    # Offsets become relative to the chunk (in place).
                    entity["start"] = entity["start"] - index
                    entity["end"] = entity["end"] - index
                    global_to_local_map[entity_id] = len(entities_in_this_span)
                    entities_in_this_span.append(entity)

            # select relations in current chunk
            relations_in_this_span = []
            for relation in relations:
                if (index <= relation["start_index"] < index + chunk_size and
                        index <= relation["end_index"] < index + chunk_size):
                    relations_in_this_span.append({
                        "head": global_to_local_map[relation["head"]],
                        "tail": global_to_local_map[relation["tail"]],
                        "start_index": relation["start_index"] - index,
                        "end_index": relation["end_index"] - index,
                    })
            item.update({
                "entities": reformat(entities_in_this_span),
                "relations": reformat(relations_in_this_span),
            })
            item['entities']['label'] = [
                self.entities_labels[x] for x in item['entities']['label']
            ]
            encoded_inputs_all.append(item)
        return encoded_inputs_all

    def _relations(self, entities, relations, id2label, empty_entity,
                   entity_id_to_index_map):
        """
        build relations

        Keeps the unique question/answer links whose two ends are not empty
        entities, with the question as head and the answer as tail, sorted by
        head index. Each relation also records the token span covering both
        entities.
        """
        relations = list(set(relations))
        relations = [
            rel for rel in relations
            if rel[0] not in empty_entity and rel[1] not in empty_entity
        ]
        kv_relations = []
        for rel in relations:
            pair = [id2label[rel[0]], id2label[rel[1]]]
            if pair == ["question", "answer"]:
                kv_relations.append({
                    "head": entity_id_to_index_map[rel[0]],
                    "tail": entity_id_to_index_map[rel[1]]
                })
            elif pair == ["answer", "question"]:
                kv_relations.append({
                    "head": entity_id_to_index_map[rel[1]],
                    "tail": entity_id_to_index_map[rel[0]]
                })
            else:
                continue
        relations = sorted(
            [{
                "head": rel["head"],
                "tail": rel["tail"],
                "start_index": get_relation_span(rel, entities)[0],
                "end_index": get_relation_span(rel, entities)[1],
            } for rel in kv_relations],
            key=lambda x: x["head"], )
        return relations

    def load_img(self, image_path):
        """Read an image, resize it to ``self.img_size`` and normalise it.

        Returns:
            numpy.ndarray: channel-first array, scaled to [0, 1] and then
            normalised with the per-channel mean / std below.
        """
        # read img
        img = cv2.imread(image_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        resize_h, resize_w = self.img_size
        im_shape = img.shape[0:2]
        im_scale_y = resize_h / im_shape[0]
        im_scale_x = resize_w / im_shape[1]
        # interpolation=2 is cv2.INTER_CUBIC.
        img_new = cv2.resize(
            img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=2)
        mean = np.array([0.485, 0.456, 0.406])[np.newaxis, np.newaxis, :]
        std = np.array([0.229, 0.224, 0.225])[np.newaxis, np.newaxis, :]
        img_new = img_new / 255.0
        img_new -= mean
        img_new /= std
        img = img_new.transpose((2, 0, 1))
        return img

    def __getitem__(self, idx):
        """Return sample ``idx`` restricted to the keys in ``return_keys``.

        Values declared as ``'np'`` are converted to arrays of the declared
        dtype; the image is loaded from disk on every call.
        """
        if self.load_mode == "all":
            data = copy.deepcopy(self.encoded_inputs_all[idx])
        else:
            # Lazy mode: only the first chunk of the line is used.
            data = self._parse_label_file(self.all_lines[idx])[0]

        image_path = data.pop('image_path')
        data["image"] = self.load_img(image_path)

        return_data = {}
        for k, v in data.items():
            if k in self.return_keys:
                if self.return_keys[k]['type'] == 'np':
                    v = np.array(v, dtype=self.return_keys[k]['dtype'])
                return_data[k] = v
        return return_data

    def __len__(self, ):
        """Number of chunks in ``'all'`` mode, otherwise number of lines."""
        if self.load_mode == "all":
            return len(self.encoded_inputs_all)
        else:
            return len(self.all_lines)
def get_relation_span(rel, entities):
    """Return the token span covering both entities of a relation.

    Args:
        rel: dict whose ``"head"`` and ``"tail"`` are indices into
            ``entities``.
        entities: sequence of dicts with ``"start"`` and ``"end"`` offsets.

    Returns:
        tuple: ``(smallest offset, largest offset)`` over the start and end
        of the two entities.
    """
    endpoints = [
        entities[rel[role]][edge]
        for role in ("head", "tail") for edge in ("start", "end")
    ]
    return min(endpoints), max(endpoints)
def reformat(data):
    """Turn a list of dicts into a dict of lists.

    Args:
        data: iterable of dicts.

    Returns:
        dict: every key seen in ``data`` mapped to the list of its values, in
        input order. Empty input gives an empty dict.
    """
    columns = {}
    for record in data:
        for field, value in record.items():
            columns.setdefault(field, []).append(value)
    return columns
requirements.txt
View file @
a323fce6
...
...
@@ -13,4 +13,3 @@ lxml
premailer
openpyxl
fasttext
==0.9.1
paddlenlp
>=2.2.1
tools/eval.py
View file @
a323fce6
...
...
@@ -61,7 +61,8 @@ def main():
else
:
model_type
=
None
best_model_dict
=
load_model
(
config
,
model
)
best_model_dict
=
load_model
(
config
,
model
,
model_type
=
config
[
'Architecture'
][
"model_type"
])
if
len
(
best_model_dict
):
logger
.
info
(
'metric in ckpt ***************'
)
for
k
,
v
in
best_model_dict
.
items
():
...
...
tools/infer_vqa_token_ser.py
0 → 100755
View file @
a323fce6
This diff is collapsed.
Click to expand it.
tools/infer_vqa_token_ser_re.py
0 → 100755
View file @
a323fce6
This diff is collapsed.
Click to expand it.
tools/program.py
View file @
a323fce6
This diff is collapsed.
Click to expand it.
tools/train.py
View file @
a323fce6
...
...
@@ -97,7 +97,8 @@ def main(config, device, logger, vdl_writer):
# build metric
eval_class
=
build_metric
(
config
[
'Metric'
])
# load pretrain model
pre_best_model_dict
=
load_model
(
config
,
model
,
optimizer
)
pre_best_model_dict
=
load_model
(
config
,
model
,
optimizer
,
config
[
'Architecture'
][
"model_type"
])
logger
.
info
(
'train dataloader has {} iters'
.
format
(
len
(
train_dataloader
)))
if
valid_dataloader
is
not
None
:
logger
.
info
(
'valid dataloader has {} iters'
.
format
(
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment