wangsen / paddle_dbnet · Commits

Commit a323fce6, authored Jan 05, 2022 by WenmuZhou
Parent: 1ded2ac4

    vqa code integrated into ppocr training system
Showing 14 changed files with 394 additions and 1870 deletions:

    ppstructure/vqa/infer_ser_e2e.py       +0    -156
    ppstructure/vqa/infer_ser_re_e2e.py    +0    -135
    ppstructure/vqa/metric.py              +0    -175
    ppstructure/vqa/requirements.txt       +2    -1
    ppstructure/vqa/train_re.py            +0    -229
    ppstructure/vqa/train_ser.py           +0    -248
    ppstructure/vqa/vqa_utils.py           +0    -400
    ppstructure/vqa/xfun.py                +0    -464
    requirements.txt                       +0    -1
    tools/eval.py                          +2    -1
    tools/infer_vqa_token_ser.py           +135  -0
    tools/infer_vqa_token_ser_re.py        +199  -0
    tools/program.py                       +54   -59
    tools/train.py                         +2    -1
ppstructure/vqa/infer_ser_e2e.py — deleted (100644 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)

import json
import cv2
import numpy as np
from copy import deepcopy
from PIL import Image

import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification

# relative reference
from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from vqa_utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info

MODELS = {
    'LayoutXLM':
    (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
    'LayoutLM':
    (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}


def trans_poly_to_bbox(poly):
    x1 = np.min([p[0] for p in poly])
    x2 = np.max([p[0] for p in poly])
    y1 = np.min([p[1] for p in poly])
    y2 = np.max([p[1] for p in poly])
    return [x1, y1, x2, y2]


def parse_ocr_info_for_ser(ocr_result):
    ocr_info = []
    for res in ocr_result:
        ocr_info.append({
            "text": res[1][0],
            "bbox": trans_poly_to_bbox(res[0]),
            "poly": res[0],
        })
    return ocr_info


class SerPredictor(object):
    def __init__(self, args):
        self.args = args
        self.max_seq_length = args.max_seq_length

        # init ser token and model
        tokenizer_class, base_model_class, model_class = MODELS[
            args.ser_model_type]
        self.tokenizer = tokenizer_class.from_pretrained(
            args.model_name_or_path)
        self.model = model_class.from_pretrained(args.model_name_or_path)
        self.model.eval()

        # init ocr_engine
        from paddleocr import PaddleOCR
        self.ocr_engine = PaddleOCR(
            rec_model_dir=args.rec_model_dir,
            det_model_dir=args.det_model_dir,
            use_angle_cls=False,
            show_log=False)

        # init dict
        label2id_map, self.id2label_map = get_bio_label_maps(
            args.label_map_path)
        self.label2id_map_for_draw = dict()
        for key in label2id_map:
            if key.startswith("I-"):
                self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
            else:
                self.label2id_map_for_draw[key] = label2id_map[key]

    def __call__(self, img):
        ocr_result = self.ocr_engine.ocr(img, cls=False)
        ocr_info = parse_ocr_info_for_ser(ocr_result)

        inputs = preprocess(
            tokenizer=self.tokenizer,
            ori_img=img,
            ocr_info=ocr_info,
            max_seq_len=self.max_seq_length)

        if self.args.ser_model_type == 'LayoutLM':
            preds = self.model(
                input_ids=inputs["input_ids"],
                bbox=inputs["bbox"],
                token_type_ids=inputs["token_type_ids"],
                attention_mask=inputs["attention_mask"])
        elif self.args.ser_model_type == 'LayoutXLM':
            preds = self.model(
                input_ids=inputs["input_ids"],
                bbox=inputs["bbox"],
                image=inputs["image"],
                token_type_ids=inputs["token_type_ids"],
                attention_mask=inputs["attention_mask"])
            preds = preds[0]

        preds = postprocess(inputs["attention_mask"], preds,
                            self.id2label_map)
        ocr_info = merge_preds_list_with_ocr_info(
            ocr_info, inputs["segment_offset_id"], preds,
            self.label2id_map_for_draw)
        return ocr_info, inputs


if __name__ == "__main__":
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # get infer img list
    infer_imgs = get_image_file_list(args.infer_imgs)

    # loop for infer
    ser_engine = SerPredictor(args)
    with open(
            os.path.join(args.output_dir, "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, img_path in enumerate(infer_imgs):
            save_img_path = os.path.join(
                args.output_dir,
                os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
            print("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))

            img = cv2.imread(img_path)

            result, _ = ser_engine(img)
            fout.write(img_path + "\t" + json.dumps(
                {
                    "ser_resule": result,
                }, ensure_ascii=False) + "\n")

            img_res = draw_ser_results(img, result)
            cv2.imwrite(save_img_path, img_res)
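
For reference, a minimal self-contained sketch (not part of this commit; polygon and score values are invented) of the per-line result format parse_ocr_info_for_ser expects from PaddleOCR's ocr(img, cls=False), i.e. one [polygon, (text, score)] entry per detected text line:

    import numpy as np

    def trans_poly_to_bbox(poly):
        # same axis-aligned reduction as in the deleted file above
        xs = [p[0] for p in poly]
        ys = [p[1] for p in poly]
        return [np.min(xs), np.min(ys), np.max(xs), np.max(ys)]

    # hypothetical OCR output for a single text line
    ocr_result = [
        [[[10, 10], [120, 12], [121, 40], [9, 38]], ("Name:", 0.98)],
    ]
    print(trans_poly_to_bbox(ocr_result[0][0]))  # [9, 10, 121, 40]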
ppstructure/vqa/infer_ser_re_e2e.py — deleted (100644 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import json
import cv2
import numpy as np
from copy import deepcopy
from PIL import Image

import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForRelationExtraction

# relative reference
from vqa_utils import parse_args, get_image_file_list, draw_re_results
from infer_ser_e2e import SerPredictor


def make_input(ser_input, ser_result, max_seq_len=512):
    entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}

    entities = ser_input['entities'][0]
    assert len(entities) == len(ser_result)

    # entities
    start = []
    end = []
    label = []
    entity_idx_dict = {}
    for i, (res, entity) in enumerate(zip(ser_result, entities)):
        if res['pred'] == 'O':
            continue
        entity_idx_dict[len(start)] = i
        start.append(entity['start'])
        end.append(entity['end'])
        label.append(entities_labels[res['pred']])

    entities = dict(start=start, end=end, label=label)

    # relations
    head = []
    tail = []
    for i in range(len(entities["label"])):
        for j in range(len(entities["label"])):
            if entities["label"][i] == 1 and entities["label"][j] == 2:
                head.append(i)
                tail.append(j)

    relations = dict(head=head, tail=tail)

    batch_size = ser_input["input_ids"].shape[0]
    entities_batch = []
    relations_batch = []
    for b in range(batch_size):
        entities_batch.append(entities)
        relations_batch.append(relations)

    ser_input['entities'] = entities_batch
    ser_input['relations'] = relations_batch

    ser_input.pop('segment_offset_id')
    return ser_input, entity_idx_dict


class SerReSystem(object):
    def __init__(self, args):
        self.ser_engine = SerPredictor(args)
        self.tokenizer = LayoutXLMTokenizer.from_pretrained(
            args.re_model_name_or_path)
        self.model = LayoutXLMForRelationExtraction.from_pretrained(
            args.re_model_name_or_path)
        self.model.eval()

    def __call__(self, img):
        ser_result, ser_inputs = self.ser_engine(img)
        re_input, entity_idx_dict = make_input(ser_inputs, ser_result)

        re_result = self.model(**re_input)

        pred_relations = re_result['pred_relations'][0]
        # map the predicted relations back to the ocr info
        result = []
        used_tail_id = []
        for relation in pred_relations:
            if relation['tail_id'] in used_tail_id:
                continue
            used_tail_id.append(relation['tail_id'])
            ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
            ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
            result.append((ocr_info_head, ocr_info_tail))

        return result


if __name__ == "__main__":
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # get infer img list
    infer_imgs = get_image_file_list(args.infer_imgs)

    # loop for infer
    ser_re_engine = SerReSystem(args)
    with open(
            os.path.join(args.output_dir, "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, img_path in enumerate(infer_imgs):
            save_img_path = os.path.join(
                args.output_dir,
                os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg")
            print("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))

            img = cv2.imread(img_path)

            result = ser_re_engine(img)
            fout.write(img_path + "\t" + json.dumps(
                {
                    "result": result,
                }, ensure_ascii=False) + "\n")

            img_res = draw_re_results(img, result)
            cv2.imwrite(save_img_path, img_res)
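
The candidate-generation step inside make_input is easy to miss: every entity tagged QUESTION (label 1) is paired with every entity tagged ANSWER (label 2), and the relation-extraction model then scores each candidate pair. A toy illustration (values invented) of that double loop:

    labels = [1, 2, 1, 2]  # two QUESTION entities, two ANSWER entities

    head, tail = [], []
    for i in range(len(labels)):
        for j in range(len(labels)):
            if labels[i] == 1 and labels[j] == 2:
                head.append(i)
                tail.append(j)

    print(list(zip(head, tail)))  # [(0, 1), (0, 3), (2, 1), (2, 3)]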
ppstructure/vqa/metric.py — deleted (100644 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import numpy as np
import logging

logger = logging.getLogger(__name__)

PREFIX_CHECKPOINT_DIR = "checkpoint"
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")


def get_last_checkpoint(folder):
    content = os.listdir(folder)
    checkpoints = [
        path for path in content
        if _re_checkpoint.search(path) is not None and os.path.isdir(
            os.path.join(folder, path))
    ]
    if len(checkpoints) == 0:
        return
    return os.path.join(
        folder,
        max(checkpoints,
            key=lambda x: int(_re_checkpoint.search(x).groups()[0])))


def re_score(pred_relations, gt_relations, mode="strict"):
    """Evaluate RE predictions

    Args:
        pred_relations (list) : list of list of predicted relations (several relations in each sentence)
        gt_relations (list) :   list of list of ground truth relations
            rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
                    "tail": (start_idx (inclusive), end_idx (exclusive)),
                    "head_type": ent_type,
                    "tail_type": ent_type,
                    "type": rel_type}
        vocab (Vocab) :         dataset vocabulary
        mode (str) :            in 'strict' or 'boundaries'"""

    assert mode in ["strict", "boundaries"]
    relation_types = [v for v in [0, 1] if not v == 0]
    scores = {
        rel: {
            "tp": 0,
            "fp": 0,
            "fn": 0
        }
        for rel in relation_types + ["ALL"]
    }

    # Count GT relations and Predicted relations
    n_sents = len(gt_relations)
    n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
    n_found = sum([len([rel for rel in sent]) for sent in pred_relations])

    # Count TP, FP and FN per type
    for pred_sent, gt_sent in zip(pred_relations, gt_relations):
        for rel_type in relation_types:
            # strict mode takes argument types into account
            if mode == "strict":
                pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
                              rel["tail_type"])
                             for rel in pred_sent if rel["type"] == rel_type}
                gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
                            rel["tail_type"])
                           for rel in gt_sent if rel["type"] == rel_type}

            # boundaries mode only takes argument spans into account
            elif mode == "boundaries":
                pred_rels = {(rel["head"], rel["tail"])
                             for rel in pred_sent if rel["type"] == rel_type}
                gt_rels = {(rel["head"], rel["tail"])
                           for rel in gt_sent if rel["type"] == rel_type}

            scores[rel_type]["tp"] += len(pred_rels & gt_rels)
            scores[rel_type]["fp"] += len(pred_rels - gt_rels)
            scores[rel_type]["fn"] += len(gt_rels - pred_rels)

    # Compute per entity Precision / Recall / F1
    for rel_type in scores.keys():
        if scores[rel_type]["tp"]:
            scores[rel_type]["p"] = scores[rel_type]["tp"] / (
                scores[rel_type]["fp"] + scores[rel_type]["tp"])
            scores[rel_type]["r"] = scores[rel_type]["tp"] / (
                scores[rel_type]["fn"] + scores[rel_type]["tp"])
        else:
            scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0

        if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
            scores[rel_type]["f1"] = (
                2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
                (scores[rel_type]["p"] + scores[rel_type]["r"]))
        else:
            scores[rel_type]["f1"] = 0

    # Compute micro F1 Scores
    tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
    fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
    fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])

    if tp:
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)
    else:
        precision, recall, f1 = 0, 0, 0

    scores["ALL"]["p"] = precision
    scores["ALL"]["r"] = recall
    scores["ALL"]["f1"] = f1
    scores["ALL"]["tp"] = tp
    scores["ALL"]["fp"] = fp
    scores["ALL"]["fn"] = fn

    # Compute Macro F1 Scores
    scores["ALL"]["Macro_f1"] = np.mean(
        [scores[ent_type]["f1"] for ent_type in relation_types])
    scores["ALL"]["Macro_p"] = np.mean(
        [scores[ent_type]["p"] for ent_type in relation_types])
    scores["ALL"]["Macro_r"] = np.mean(
        [scores[ent_type]["r"] for ent_type in relation_types])

    # logger.info(f"RE Evaluation in *** {mode.upper()} *** mode")
    # logger.info(
    #     "processed {} sentences with {} relations; found: {} relations; correct: {}.".format(
    #         n_sents, n_rels, n_found, tp
    #     )
    # )
    # logger.info(
    #     "\tALL\t TP: {};\tFP: {};\tFN: {}".format(scores["ALL"]["tp"], scores["ALL"]["fp"], scores["ALL"]["fn"])
    # )
    # logger.info("\t\t(m avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (micro)".format(precision, recall, f1))
    # logger.info(
    #     "\t\t(M avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (Macro)\n".format(
    #         scores["ALL"]["Macro_p"], scores["ALL"]["Macro_r"], scores["ALL"]["Macro_f1"]
    #     )
    # )
    # for rel_type in relation_types:
    #     logger.info(
    #         "\t{}: \tTP: {};\tFP: {};\tFN: {};\tprecision: {:.2f};\trecall: {:.2f};\tf1: {:.2f};\t{}".format(
    #             rel_type,
    #             scores[rel_type]["tp"],
    #             scores[rel_type]["fp"],
    #             scores[rel_type]["fn"],
    #             scores[rel_type]["p"],
    #             scores[rel_type]["r"],
    #             scores[rel_type]["f1"],
    #             scores[rel_type]["tp"] + scores[rel_type]["fp"],
    #         )
    #     )

    return scores
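
A hypothetical call to re_score (assuming the deleted module above is importable as metric; the spans and types are invented). One sentence with one gold relation and two predictions, of which one is correct, gives tp=1, fp=1, fn=0:

    from metric import re_score

    gt = [[{"head": (0, 2), "head_type": 1, "tail": (3, 5), "tail_type": 2,
            "type": 1}]]
    pred = [[
        {"head": (0, 2), "head_type": 1, "tail": (3, 5), "tail_type": 2,
         "type": 1},
        {"head": (0, 2), "head_type": 1, "tail": (6, 8), "tail_type": 2,
         "type": 1},
    ]]

    scores = re_score(pred, gt, mode="strict")
    # precision 0.5, recall 1.0, f1 ~ 0.667
    print(scores["ALL"]["p"], scores["ALL"]["r"], scores["ALL"]["f1"])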
ppstructure/vqa/requirements.txt — modified (+2 −1)

 sentencepiece
 yacs
-seqeval
\ No newline at end of file
+seqeval
+paddlenlp>=2.2.1
\ No newline at end of file
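
Read together with the root requirements.txt change further below, the net effect is that paddlenlp>=2.2.1 moves out of the top-level requirements and is pulled in via pip install -r ppstructure/vqa/requirements.txt only when the VQA module is used.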
ppstructure/vqa/train_re.py — deleted (100644 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import random
import time
import numpy as np
import paddle

from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction

from xfun import XFUNDataset
from vqa_utils import parse_args, get_bio_label_maps, print_arguments, set_seed
from data_collator import DataCollator
from eval_re import evaluate

from ppocr.utils.logging import get_logger


def train(args):
    logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
    rank = paddle.distributed.get_rank()
    distributed = paddle.distributed.get_world_size() > 1

    print_arguments(args, logger)

    # Added here for reproducibility (even between python 2 and 3)
    set_seed(args.seed)

    label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
    pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index

    # dist mode
    if distributed:
        paddle.distributed.init_parallel_env()

    tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
    if not args.resume:
        model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
        model = LayoutXLMForRelationExtraction(model, dropout=None)
        logger.info('train from scratch')
    else:
        logger.info('resume from {}'.format(args.model_name_or_path))
        model = LayoutXLMForRelationExtraction.from_pretrained(
            args.model_name_or_path)

    # dist mode
    if distributed:
        model = paddle.DataParallel(model)

    train_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.train_data_dir,
        label_path=args.train_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        max_seq_len=args.max_seq_length,
        pad_token_label_id=pad_token_label_id,
        contains_re=True,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    eval_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.eval_data_dir,
        label_path=args.eval_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        max_seq_len=args.max_seq_length,
        pad_token_label_id=pad_token_label_id,
        contains_re=True,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    train_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)

    train_dataloader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=args.num_workers,
        use_shared_memory=True,
        collate_fn=DataCollator())

    eval_dataloader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=args.per_gpu_eval_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        collate_fn=DataCollator())

    t_total = len(train_dataloader) * args.num_train_epochs

    # build linear decay with warmup lr sch
    lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
        learning_rate=args.learning_rate,
        decay_steps=t_total,
        end_lr=0.0,
        power=1.0)
    if args.warmup_steps > 0:
        lr_scheduler = paddle.optimizer.lr.LinearWarmup(
            lr_scheduler,
            args.warmup_steps,
            start_lr=0,
            end_lr=args.learning_rate, )

    grad_clip = paddle.nn.ClipGradByNorm(clip_norm=10)
    optimizer = paddle.optimizer.Adam(
        learning_rate=args.learning_rate,
        parameters=model.parameters(),
        epsilon=args.adam_epsilon,
        grad_clip=grad_clip,
        weight_decay=args.weight_decay)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = {}".format(len(train_dataset)))
    logger.info("  Num Epochs = {}".format(args.num_train_epochs))
    logger.info("  Instantaneous batch size per GPU = {}".format(
        args.per_gpu_train_batch_size))
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = {}".
        format(args.per_gpu_train_batch_size *
               paddle.distributed.get_world_size()))
    logger.info("  Total optimization steps = {}".format(t_total))

    global_step = 0
    model.clear_gradients()
    train_dataloader_len = len(train_dataloader)
    best_metirc = {'f1': 0}
    model.train()

    train_reader_cost = 0.0
    train_run_cost = 0.0
    total_samples = 0
    reader_start = time.time()

    print_step = 1

    for epoch in range(int(args.num_train_epochs)):
        for step, batch in enumerate(train_dataloader):
            train_reader_cost += time.time() - reader_start
            train_start = time.time()
            outputs = model(**batch)
            train_run_cost += time.time() - train_start
            # model outputs are always tuple in ppnlp (see doc)
            loss = outputs['loss']
            loss = loss.mean()

            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            # lr_scheduler.step()  # Update learning rate schedule

            global_step += 1
            total_samples += batch['image'].shape[0]

            if rank == 0 and step % print_step == 0:
                logger.info(
                    "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
                    format(epoch, args.num_train_epochs, step,
                           train_dataloader_len, global_step,
                           np.mean(loss.numpy()),
                           optimizer.get_lr(),
                           train_reader_cost / print_step,
                           (train_reader_cost + train_run_cost) / print_step,
                           total_samples / print_step, total_samples / (
                               train_reader_cost + train_run_cost)))

                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0

            if (rank == 0 and args.eval_steps > 0 and
                    global_step % args.eval_steps == 0 and
                    args.evaluate_during_training):
                # Log metrics
                # Only evaluate when single GPU otherwise metrics may not average well
                results = evaluate(model, eval_dataloader, logger)

                if results['f1'] >= best_metirc['f1']:
                    best_metirc = results
                    output_dir = os.path.join(args.output_dir, "best_model")
                    os.makedirs(output_dir, exist_ok=True)
                    if distributed:
                        model._layers.save_pretrained(output_dir)
                    else:
                        model.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    paddle.save(args,
                                os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to {}".format(
                        output_dir))
                logger.info("eval results: {}".format(results))
                logger.info("best_metirc: {}".format(best_metirc))
            reader_start = time.time()

        if rank == 0:
            # Save model checkpoint
            output_dir = os.path.join(args.output_dir, "latest_model")
            os.makedirs(output_dir, exist_ok=True)
            if distributed:
                model._layers.save_pretrained(output_dir)
            else:
                model.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)
            paddle.save(args, os.path.join(output_dir, "training_args.bin"))
            logger.info("Saving model checkpoint to {}".format(output_dir))
    logger.info("best_metirc: {}".format(best_metirc))


if __name__ == "__main__":
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    train(args)
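
For orientation, a standalone sketch (hyperparameter values invented) of the warmup-plus-polynomial-decay schedule built in train(). Note that in this file the Adam optimizer is constructed with the constant args.learning_rate and lr_scheduler.step() is commented out, so the schedule is effectively unused here, whereas train_ser.py below passes the scheduler to the optimizer directly:

    import paddle

    base_lr, t_total, warmup_steps = 5e-5, 1000, 100
    sched = paddle.optimizer.lr.PolynomialDecay(
        learning_rate=base_lr, decay_steps=t_total, end_lr=0.0, power=1.0)
    sched = paddle.optimizer.lr.LinearWarmup(
        sched, warmup_steps, start_lr=0, end_lr=base_lr)

    for _ in range(5):
        print(sched.get_lr())  # ramps linearly from 0 toward base_lr
        sched.step()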
ppstructure/vqa/train_ser.py — deleted (100644 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import random
import time
import copy
import logging
import argparse
import paddle
import numpy as np
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
from xfun import XFUNDataset
from vqa_utils import parse_args, get_bio_label_maps, print_arguments, set_seed
from eval_ser import evaluate
from losses import SERLoss
from ppocr.utils.logging import get_logger

MODELS = {
    'LayoutXLM':
    (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
    'LayoutLM':
    (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}


def train(args):
    os.makedirs(args.output_dir, exist_ok=True)
    rank = paddle.distributed.get_rank()
    distributed = paddle.distributed.get_world_size() > 1

    logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
    print_arguments(args, logger)

    label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
    loss_class = SERLoss(len(label2id_map))

    pad_token_label_id = loss_class.ignore_index

    # dist mode
    if distributed:
        paddle.distributed.init_parallel_env()

    tokenizer_class, base_model_class, model_class = MODELS[
        args.ser_model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    if not args.resume:
        base_model = base_model_class.from_pretrained(args.model_name_or_path)
        model = model_class(
            base_model, num_classes=len(label2id_map), dropout=None)
        logger.info('train from scratch')
    else:
        logger.info('resume from {}'.format(args.model_name_or_path))
        model = model_class.from_pretrained(args.model_name_or_path)

    # dist mode
    if distributed:
        model = paddle.DataParallel(model)

    train_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.train_data_dir,
        label_path=args.train_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        pad_token_label_id=pad_token_label_id,
        contains_re=False,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    eval_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.eval_data_dir,
        label_path=args.eval_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        pad_token_label_id=pad_token_label_id,
        contains_re=False,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    train_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)

    train_dataloader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=args.num_workers,
        use_shared_memory=True,
        collate_fn=None, )

    eval_dataloader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=args.per_gpu_eval_batch_size,
        num_workers=args.num_workers,
        use_shared_memory=True,
        collate_fn=None, )

    t_total = len(train_dataloader) * args.num_train_epochs

    # build linear decay with warmup lr sch
    lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
        learning_rate=args.learning_rate,
        decay_steps=t_total,
        end_lr=0.0,
        power=1.0)
    if args.warmup_steps > 0:
        lr_scheduler = paddle.optimizer.lr.LinearWarmup(
            lr_scheduler,
            args.warmup_steps,
            start_lr=0,
            end_lr=args.learning_rate, )

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        epsilon=args.adam_epsilon,
        weight_decay=args.weight_decay)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed) = %d",
        args.per_gpu_train_batch_size * paddle.distributed.get_world_size(), )
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss = 0.0
    set_seed(args.seed)
    best_metrics = None

    train_reader_cost = 0.0
    train_run_cost = 0.0
    total_samples = 0
    reader_start = time.time()

    print_step = 1
    model.train()

    for epoch_id in range(args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            train_reader_cost += time.time() - reader_start

            if args.ser_model_type == 'LayoutLM':
                if 'image' in batch:
                    batch.pop('image')
            labels = batch.pop('labels')

            train_start = time.time()
            outputs = model(**batch)
            train_run_cost += time.time() - train_start
            if args.ser_model_type == 'LayoutXLM':
                outputs = outputs[0]
            loss = loss_class(labels, outputs, batch['attention_mask'])

            # model outputs are always tuple in ppnlp (see doc)
            loss = loss.mean()
            loss.backward()
            tr_loss += loss.item()
            optimizer.step()
            lr_scheduler.step()  # Update learning rate schedule
            optimizer.clear_grad()
            global_step += 1
            total_samples += batch['input_ids'].shape[0]

            if rank == 0 and step % print_step == 0:
                logger.info(
                    "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
                    format(epoch_id, args.num_train_epochs, step,
                           len(train_dataloader), global_step,
                           loss.numpy()[0],
                           lr_scheduler.get_lr(),
                           train_reader_cost / print_step,
                           (train_reader_cost + train_run_cost) / print_step,
                           total_samples / print_step, total_samples / (
                               train_reader_cost + train_run_cost)))

                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0

            if (rank == 0 and args.eval_steps > 0 and
                    global_step % args.eval_steps == 0 and
                    args.evaluate_during_training):
                # Log metrics
                # Only evaluate when single GPU otherwise metrics may not average well
                results, _ = evaluate(args, model, tokenizer, loss_class,
                                      eval_dataloader, label2id_map,
                                      id2label_map, pad_token_label_id,
                                      logger)

                if best_metrics is None or results["f1"] >= best_metrics["f1"]:
                    best_metrics = copy.deepcopy(results)
                    output_dir = os.path.join(args.output_dir, "best_model")
                    os.makedirs(output_dir, exist_ok=True)
                    if distributed:
                        model._layers.save_pretrained(output_dir)
                    else:
                        model.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    paddle.save(args,
                                os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to {}".format(
                        output_dir))

                logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format(
                    epoch_id, args.num_train_epochs, step,
                    len(train_dataloader), results))
                if best_metrics is not None:
                    logger.info("best metrics: {}".format(best_metrics))
            reader_start = time.time()

        if rank == 0:
            # Save model checkpoint
            output_dir = os.path.join(args.output_dir, "latest_model")
            os.makedirs(output_dir, exist_ok=True)
            if distributed:
                model._layers.save_pretrained(output_dir)
            else:
                model.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)
            paddle.save(args, os.path.join(output_dir, "training_args.bin"))
            logger.info("Saving model checkpoint to {}".format(output_dir))
    return global_step, tr_loss / global_step


if __name__ == "__main__":
    args = parse_args()
    train(args)
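
One detail worth calling out in the loop above: LayoutLM is a text-plus-layout model with no visual branch, so the 'image' tensor is popped from the batch before forwarding, and 'labels' is always held aside for SERLoss. A minimal dict-level illustration (tensor contents elided as strings):

    batch = {"input_ids": "...", "token_type_ids": "...", "bbox": "...",
             "attention_mask": "...", "image": "...", "labels": "..."}

    ser_model_type = "LayoutLM"
    if ser_model_type == "LayoutLM":
        batch.pop("image", None)
    labels = batch.pop("labels")

    print(sorted(batch))  # what model(**batch) receives
    # ['attention_mask', 'bbox', 'input_ids', 'token_type_ids']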
ppstructure/vqa/vqa_utils.py — deleted (100644 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse

import cv2
import random
import numpy as np
import imghdr
from copy import deepcopy

import paddle

from PIL import Image, ImageDraw, ImageFont


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)


def get_bio_label_maps(label_map_path):
    with open(label_map_path, "r", encoding='utf-8') as fin:
        lines = fin.readlines()
    lines = [line.strip() for line in lines]
    if "O" not in lines:
        lines.insert(0, "O")
    labels = []
    for line in lines:
        if line == "O":
            labels.append("O")
        else:
            labels.append("B-" + line)
            labels.append("I-" + line)
    label2id_map = {label: idx for idx, label in enumerate(labels)}
    id2label_map = {idx: label for idx, label in enumerate(labels)}
    return label2id_map, id2label_map


def get_image_file_list(img_file):
    imgs_lists = []
    if img_file is None or not os.path.exists(img_file):
        raise Exception("not found any img file in {}".format(img_file))

    img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}
    if os.path.isfile(img_file) and imghdr.what(img_file) in img_end:
        imgs_lists.append(img_file)
    elif os.path.isdir(img_file):
        for single_file in os.listdir(img_file):
            file_path = os.path.join(img_file, single_file)
            if os.path.isfile(file_path) and imghdr.what(file_path) in img_end:
                imgs_lists.append(file_path)
    if len(imgs_lists) == 0:
        raise Exception("not found any img file in {}".format(img_file))
    imgs_lists = sorted(imgs_lists)
    return imgs_lists


def draw_ser_results(image,
                     ocr_results,
                     font_path="../../doc/fonts/simfang.ttf",
                     font_size=18):
    np.random.seed(2021)
    color = (np.random.permutation(range(255)),
             np.random.permutation(range(255)),
             np.random.permutation(range(255)))
    color_map = {
        idx: (color[0][idx], color[1][idx], color[2][idx])
        for idx in range(1, 255)
    }
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    img_new = image.copy()
    draw = ImageDraw.Draw(img_new)

    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    for ocr_info in ocr_results:
        if ocr_info["pred_id"] not in color_map:
            continue
        color = color_map[ocr_info["pred_id"]]
        text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])

        draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)

    img_new = Image.blend(image, img_new, 0.5)
    return np.array(img_new)


def draw_box_txt(bbox, text, draw, font, font_size, color):
    # draw ocr results outline
    bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
    draw.rectangle(bbox, fill=color)

    # draw ocr results
    start_y = max(0, bbox[0][1] - font_size)
    tw = font.getsize(text)[0]
    draw.rectangle(
        [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1,
                                     start_y + font_size)],
        fill=(0, 0, 255))
    draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)


def draw_re_results(image,
                    result,
                    font_path="../../doc/fonts/simfang.ttf",
                    font_size=18):
    np.random.seed(0)
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    img_new = image.copy()
    draw = ImageDraw.Draw(img_new)

    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    color_head = (0, 0, 255)
    color_tail = (255, 0, 0)
    color_line = (0, 255, 0)

    for ocr_info_head, ocr_info_tail in result:
        draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
                     font_size, color_head)
        draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
                     font_size, color_tail)

        center_head = (
            (ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
            (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
        center_tail = (
            (ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
            (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)

        draw.line([center_head, center_tail], fill=color_line, width=5)

    img_new = Image.blend(image, img_new, 0.5)
    return np.array(img_new)


# pad sentences
def pad_sentences(tokenizer,
                  encoded_inputs,
                  max_seq_len=512,
                  pad_to_max_seq_len=True,
                  return_attention_mask=True,
                  return_token_type_ids=True,
                  return_overflowing_tokens=False,
                  return_special_tokens_mask=False):
    # Padding with larger size, reshape is carried out
    max_seq_len = (
        len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len

    needs_to_be_padded = pad_to_max_seq_len and \
                         max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len

    if needs_to_be_padded:
        difference = max_seq_len - len(encoded_inputs["input_ids"])
        if tokenizer.padding_side == 'right':
            if return_attention_mask:
                encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
                    "input_ids"]) + [0] * difference
            if return_token_type_ids:
                encoded_inputs["token_type_ids"] = (
                    encoded_inputs["token_type_ids"] +
                    [tokenizer.pad_token_type_id] * difference)
            if return_special_tokens_mask:
                encoded_inputs["special_tokens_mask"] = encoded_inputs[
                    "special_tokens_mask"] + [1] * difference
            encoded_inputs["input_ids"] = encoded_inputs[
                "input_ids"] + [tokenizer.pad_token_id] * difference
            encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
                                                               ] * difference
    else:
        if return_attention_mask:
            encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
                "input_ids"])

    return encoded_inputs


def split_page(encoded_inputs, max_seq_len=512):
    """
    truncate is often used in training process
    """
    for key in encoded_inputs:
        if key == 'entities':
            encoded_inputs[key] = [encoded_inputs[key]]
            continue
        encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
        if encoded_inputs[key].ndim <= 1:
            # for input_ids, att_mask and so on
            encoded_inputs[key] = encoded_inputs[key].reshape(
                [-1, max_seq_len])
        else:
            # for bbox
            encoded_inputs[key] = encoded_inputs[key].reshape(
                [-1, max_seq_len, 4])
    return encoded_inputs


def preprocess(tokenizer,
               ori_img,
               ocr_info,
               img_size=(224, 224),
               pad_token_label_id=-100,
               max_seq_len=512,
               add_special_ids=False,
               return_attention_mask=True, ):
    ocr_info = deepcopy(ocr_info)
    height = ori_img.shape[0]
    width = ori_img.shape[1]

    img = cv2.resize(ori_img,
                     img_size).transpose([2, 0, 1]).astype(np.float32)

    segment_offset_id = []
    words_list = []
    bbox_list = []
    input_ids_list = []
    token_type_ids_list = []
    entities = []

    for info in ocr_info:
        # x1, y1, x2, y2
        bbox = info["bbox"]
        bbox[0] = int(bbox[0] * 1000.0 / width)
        bbox[2] = int(bbox[2] * 1000.0 / width)
        bbox[1] = int(bbox[1] * 1000.0 / height)
        bbox[3] = int(bbox[3] * 1000.0 / height)

        text = info["text"]
        encode_res = tokenizer.encode(
            text, pad_to_max_seq_len=False, return_attention_mask=True)

        if not add_special_ids:
            # TODO: use tok.all_special_ids to remove
            encode_res["input_ids"] = encode_res["input_ids"][1:-1]
            encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
            encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]

        # for re
        entities.append({
            "start": len(input_ids_list),
            "end": len(input_ids_list) + len(encode_res["input_ids"]),
            "label": "O",
        })

        input_ids_list.extend(encode_res["input_ids"])
        token_type_ids_list.extend(encode_res["token_type_ids"])
        bbox_list.extend([bbox] * len(encode_res["input_ids"]))
        words_list.append(text)
        segment_offset_id.append(len(input_ids_list))

    encoded_inputs = {
        "input_ids": input_ids_list,
        "token_type_ids": token_type_ids_list,
        "bbox": bbox_list,
        "attention_mask": [1] * len(input_ids_list),
        "entities": entities
    }

    encoded_inputs = pad_sentences(
        tokenizer,
        encoded_inputs,
        max_seq_len=max_seq_len,
        return_attention_mask=return_attention_mask)

    encoded_inputs = split_page(encoded_inputs)

    fake_bs = encoded_inputs["input_ids"].shape[0]
    encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
        [fake_bs] + list(img.shape))

    encoded_inputs["segment_offset_id"] = segment_offset_id
    return encoded_inputs


def postprocess(attention_mask, preds, id2label_map):
    if isinstance(preds, paddle.Tensor):
        preds = preds.numpy()
    preds = np.argmax(preds, axis=2)

    preds_list = [[] for _ in range(preds.shape[0])]

    # keep batch info
    for i in range(preds.shape[0]):
        for j in range(preds.shape[1]):
            if attention_mask[i][j] == 1:
                preds_list[i].append(id2label_map[preds[i][j]])

    return preds_list


def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list,
                                   label2id_map_for_draw):
    # must ensure the preds_list is generated from the same image
    preds = [p for pred in preds_list for p in pred]

    id2label_map = dict()
    for key in label2id_map_for_draw:
        val = label2id_map_for_draw[key]
        if key == "O":
            id2label_map[val] = key
        if key.startswith("B-") or key.startswith("I-"):
            id2label_map[val] = key[2:]
        else:
            id2label_map[val] = key

    for idx in range(len(segment_offset_id)):
        if idx == 0:
            start_id = 0
        else:
            start_id = segment_offset_id[idx - 1]

        end_id = segment_offset_id[idx]

        curr_pred = preds[start_id:end_id]
        curr_pred = [label2id_map_for_draw[p] for p in curr_pred]

        if len(curr_pred) <= 0:
            pred_id = 0
        else:
            counts = np.bincount(curr_pred)
            pred_id = np.argmax(counts)
        ocr_info[idx]["pred_id"] = int(pred_id)
        ocr_info[idx]["pred"] = id2label_map[int(pred_id)]
    return ocr_info


def print_arguments(args, logger=None):
    print_func = logger.info if logger is not None else print
    """print arguments"""
    print_func('-----------  Configuration Arguments -----------')
    for arg, value in sorted(vars(args).items()):
        print_func('%s: %s' % (arg, value))
    print_func('------------------------------------------------')


def parse_args():
    parser = argparse.ArgumentParser()
    # Required parameters
    # yapf: disable
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,)
    parser.add_argument("--ser_model_type", default='LayoutXLM', type=str)
    parser.add_argument("--re_model_name_or_path", default=None, type=str, required=False,)
    parser.add_argument("--train_data_dir", default=None, type=str, required=False,)
    parser.add_argument("--train_label_path", default=None, type=str, required=False,)
    parser.add_argument("--eval_data_dir", default=None, type=str, required=False,)
    parser.add_argument("--eval_label_path", default=None, type=str, required=False,)
    parser.add_argument("--output_dir", default=None, type=str, required=True,)
    parser.add_argument("--max_seq_length", default=512, type=int,)
    parser.add_argument("--evaluate_during_training", action="store_true",)
    parser.add_argument("--num_workers", default=8, type=int,)
    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",)
    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for eval.",)
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.",)
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.",)
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.",)
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.",)
    parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.",)
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.",)
    parser.add_argument("--eval_steps", type=int, default=10, help="eval every X updates steps.",)
    parser.add_argument("--seed", type=int, default=2048, help="random seed for initialization",)
    parser.add_argument("--rec_model_dir", default=None, type=str, )
    parser.add_argument("--det_model_dir", default=None, type=str, )
    parser.add_argument("--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
    parser.add_argument("--infer_imgs", default=None, type=str, required=False)
    parser.add_argument("--resume", action='store_true')
    parser.add_argument("--ocr_json_path", default=None, type=str, required=False, help="ocr prediction results")
    # yapf: enable
    args = parser.parse_args()
    return args
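
A runnable sketch of get_bio_label_maps (assuming vqa_utils is importable); the three class names mirror the SER labels used throughout this commit. The helper prepends "O" and expands each class into B-/I- tags:

    import os
    import tempfile

    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
        f.write("HEADER\nQUESTION\nANSWER\n")
        label_map_path = f.name

    from vqa_utils import get_bio_label_maps
    label2id, id2label = get_bio_label_maps(label_map_path)
    print(label2id)
    # {'O': 0, 'B-HEADER': 1, 'I-HEADER': 2, 'B-QUESTION': 3,
    #  'I-QUESTION': 4, 'B-ANSWER': 5, 'I-ANSWER': 6}
    os.remove(label_map_path)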
ppstructure/vqa/xfun.py — deleted (100644 → 0)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
os
import
cv2
import
numpy
as
np
import
paddle
import
copy
from
paddle.io
import
Dataset
__all__
=
[
"XFUNDataset"
]
class
XFUNDataset
(
Dataset
):
"""
Example:
print("=====begin to build dataset=====")
from paddlenlp.transformers import LayoutXLMTokenizer
tokenizer = LayoutXLMTokenizer.from_pretrained("/paddle/models/transformers/layoutxlm-base-paddle/")
tok_res = tokenizer.tokenize("Maribyrnong")
# res = tokenizer.convert_ids_to_tokens(val_data["input_ids"][0])
dataset = XfunDatasetForSer(
tokenizer,
data_dir="./zh.val/",
label_path="zh.val/xfun_normalize_val.json",
img_size=(224,224))
print(len(dataset))
data = dataset[0]
print(data.keys())
print("input_ids: ", data["input_ids"])
print("labels: ", data["labels"])
print("token_type_ids: ", data["token_type_ids"])
print("words_list: ", data["words_list"])
print("image shape: ", data["image"].shape)
"""
def
__init__
(
self
,
tokenizer
,
data_dir
,
label_path
,
contains_re
=
False
,
label2id_map
=
None
,
img_size
=
(
224
,
224
),
pad_token_label_id
=
None
,
add_special_ids
=
False
,
return_attention_mask
=
True
,
load_mode
=
'all'
,
max_seq_len
=
512
):
super
().
__init__
()
self
.
tokenizer
=
tokenizer
self
.
data_dir
=
data_dir
self
.
label_path
=
label_path
self
.
contains_re
=
contains_re
self
.
label2id_map
=
label2id_map
self
.
img_size
=
img_size
self
.
pad_token_label_id
=
pad_token_label_id
self
.
add_special_ids
=
add_special_ids
self
.
return_attention_mask
=
return_attention_mask
self
.
load_mode
=
load_mode
self
.
max_seq_len
=
max_seq_len
if
self
.
pad_token_label_id
is
None
:
self
.
pad_token_label_id
=
paddle
.
nn
.
CrossEntropyLoss
().
ignore_index
self
.
all_lines
=
self
.
read_all_lines
()
self
.
entities_labels
=
{
'HEADER'
:
0
,
'QUESTION'
:
1
,
'ANSWER'
:
2
}
self
.
return_keys
=
{
'bbox'
:
{
'type'
:
'np'
,
'dtype'
:
'int64'
},
'input_ids'
:
{
'type'
:
'np'
,
'dtype'
:
'int64'
},
'labels'
:
{
'type'
:
'np'
,
'dtype'
:
'int64'
},
'attention_mask'
:
{
'type'
:
'np'
,
'dtype'
:
'int64'
},
'image'
:
{
'type'
:
'np'
,
'dtype'
:
'float32'
},
'token_type_ids'
:
{
'type'
:
'np'
,
'dtype'
:
'int64'
},
'entities'
:
{
'type'
:
'dict'
},
'relations'
:
{
'type'
:
'dict'
}
}
if
load_mode
==
"all"
:
self
.
encoded_inputs_all
=
self
.
_parse_label_file_all
()
def
pad_sentences
(
self
,
encoded_inputs
,
max_seq_len
=
512
,
pad_to_max_seq_len
=
True
,
return_attention_mask
=
True
,
return_token_type_ids
=
True
,
truncation_strategy
=
"longest_first"
,
return_overflowing_tokens
=
False
,
return_special_tokens_mask
=
False
):
# Padding
needs_to_be_padded
=
pad_to_max_seq_len
and
\
max_seq_len
and
len
(
encoded_inputs
[
"input_ids"
])
<
max_seq_len
if
needs_to_be_padded
:
difference
=
max_seq_len
-
len
(
encoded_inputs
[
"input_ids"
])
if
self
.
tokenizer
.
padding_side
==
'right'
:
if
return_attention_mask
:
encoded_inputs
[
"attention_mask"
]
=
[
1
]
*
len
(
encoded_inputs
[
"input_ids"
])
+
[
0
]
*
difference
if
return_token_type_ids
:
encoded_inputs
[
"token_type_ids"
]
=
(
encoded_inputs
[
"token_type_ids"
]
+
[
self
.
tokenizer
.
pad_token_type_id
]
*
difference
)
if
return_special_tokens_mask
:
encoded_inputs
[
"special_tokens_mask"
]
=
encoded_inputs
[
"special_tokens_mask"
]
+
[
1
]
*
difference
encoded_inputs
[
"input_ids"
]
=
encoded_inputs
[
"input_ids"
]
+
[
self
.
tokenizer
.
pad_token_id
]
*
difference
encoded_inputs
[
"labels"
]
=
encoded_inputs
[
"labels"
]
+
[
self
.
pad_token_label_id
]
*
difference
encoded_inputs
[
"bbox"
]
=
encoded_inputs
[
"bbox"
]
+
[[
0
,
0
,
0
,
0
]]
*
difference
elif
self
.
tokenizer
.
padding_side
==
'left'
:
if
return_attention_mask
:
encoded_inputs
[
"attention_mask"
]
=
[
0
]
*
difference
+
[
1
]
*
len
(
encoded_inputs
[
"input_ids"
])
if
return_token_type_ids
:
encoded_inputs
[
"token_type_ids"
]
=
(
[
self
.
tokenizer
.
pad_token_type_id
]
*
difference
+
encoded_inputs
[
"token_type_ids"
])
if
return_special_tokens_mask
:
encoded_inputs
[
"special_tokens_mask"
]
=
[
1
]
*
difference
+
encoded_inputs
[
"special_tokens_mask"
]
encoded_inputs
[
"input_ids"
]
=
[
self
.
tokenizer
.
pad_token_id
]
*
difference
+
encoded_inputs
[
"input_ids"
]
encoded_inputs
[
"labels"
]
=
[
self
.
pad_token_label_id
]
*
difference
+
encoded_inputs
[
"labels"
]
encoded_inputs
[
"bbox"
]
=
[
[
0
,
0
,
0
,
0
]
]
*
difference
+
encoded_inputs
[
"bbox"
]
else
:
if
return_attention_mask
:
encoded_inputs
[
"attention_mask"
]
=
[
1
]
*
len
(
encoded_inputs
[
"input_ids"
])
return
encoded_inputs
def
truncate_inputs
(
self
,
encoded_inputs
,
max_seq_len
=
512
):
for
key
in
encoded_inputs
:
if
key
==
"sample_id"
:
continue
length
=
min
(
len
(
encoded_inputs
[
key
]),
max_seq_len
)
encoded_inputs
[
key
]
=
encoded_inputs
[
key
][:
length
]
return
encoded_inputs
def
read_all_lines
(
self
,
):
with
open
(
self
.
label_path
,
"r"
,
encoding
=
'utf-8'
)
as
fin
:
lines
=
fin
.
readlines
()
return
lines
def
_parse_label_file_all
(
self
):
"""
parse all samples
"""
encoded_inputs_all
=
[]
for
line
in
self
.
all_lines
:
encoded_inputs_all
.
extend
(
self
.
_parse_label_file
(
line
))
return
encoded_inputs_all
def
_parse_label_file
(
self
,
line
):
"""
parse single sample
"""
image_name
,
info_str
=
line
.
split
(
"
\t
"
)
image_path
=
os
.
path
.
join
(
self
.
data_dir
,
image_name
)
def
add_imgge_path
(
x
):
x
[
'image_path'
]
=
image_path
return
x
encoded_inputs
=
self
.
_read_encoded_inputs_sample
(
info_str
)
if
self
.
contains_re
:
encoded_inputs
=
self
.
_chunk_re
(
encoded_inputs
)
else
:
encoded_inputs
=
self
.
_chunk_ser
(
encoded_inputs
)
encoded_inputs
=
list
(
map
(
add_imgge_path
,
encoded_inputs
))
return
encoded_inputs
def
_read_encoded_inputs_sample
(
self
,
info_str
):
"""
parse label info
"""
# read text info
info_dict
=
json
.
loads
(
info_str
)
height
=
info_dict
[
"height"
]
width
=
info_dict
[
"width"
]
words_list
=
[]
bbox_list
=
[]
input_ids_list
=
[]
token_type_ids_list
=
[]
gt_label_list
=
[]
if
self
.
contains_re
:
# for re
entities
=
[]
relations
=
[]
id2label
=
{}
entity_id_to_index_map
=
{}
empty_entity
=
set
()
for
info
in
info_dict
[
"ocr_info"
]:
if
self
.
contains_re
:
# for re
if
len
(
info
[
"text"
])
==
0
:
empty_entity
.
add
(
info
[
"id"
])
continue
id2label
[
info
[
"id"
]]
=
info
[
"label"
]
relations
.
extend
([
tuple
(
sorted
(
l
))
for
l
in
info
[
"linking"
]])
# x1, y1, x2, y2
bbox
=
info
[
"bbox"
]
label
=
info
[
"label"
]
bbox
[
0
]
=
int
(
bbox
[
0
]
*
1000.0
/
width
)
bbox
[
2
]
=
int
(
bbox
[
2
]
*
1000.0
/
width
)
bbox
[
1
]
=
int
(
bbox
[
1
]
*
1000.0
/
height
)
bbox
[
3
]
=
int
(
bbox
[
3
]
*
1000.0
/
height
)
text
=
info
[
"text"
]
encode_res
=
self
.
tokenizer
.
encode
(
text
,
pad_to_max_seq_len
=
False
,
return_attention_mask
=
True
)
gt_label
=
[]
if
not
self
.
add_special_ids
:
# TODO: use tok.all_special_ids to remove
encode_res
[
"input_ids"
]
=
encode_res
[
"input_ids"
][
1
:
-
1
]
encode_res
[
"token_type_ids"
]
=
encode_res
[
"token_type_ids"
][
1
:
-
1
]
encode_res
[
"attention_mask"
]
=
encode_res
[
"attention_mask"
][
1
:
-
1
]
if
label
.
lower
()
==
"other"
:
gt_label
.
extend
([
0
]
*
len
(
encode_res
[
"input_ids"
]))
else
:
gt_label
.
append
(
self
.
label2id_map
[(
"b-"
+
label
).
upper
()])
gt_label
.
extend
([
self
.
label2id_map
[(
"i-"
+
label
).
upper
()]]
*
(
len
(
encode_res
[
"input_ids"
])
-
1
))
if
self
.
contains_re
:
if
gt_label
[
0
]
!=
self
.
label2id_map
[
"O"
]:
entity_id_to_index_map
[
info
[
"id"
]]
=
len
(
entities
)
entities
.
append
({
"start"
:
len
(
input_ids_list
),
"end"
:
len
(
input_ids_list
)
+
len
(
encode_res
[
"input_ids"
]),
"label"
:
label
.
upper
(),
})
input_ids_list
.
extend
(
encode_res
[
"input_ids"
])
token_type_ids_list
.
extend
(
encode_res
[
"token_type_ids"
])
bbox_list
.
extend
([
bbox
]
*
len
(
encode_res
[
"input_ids"
]))
gt_label_list
.
extend
(
gt_label
)
words_list
.
append
(
text
)
encoded_inputs
=
{
"input_ids"
:
input_ids_list
,
"labels"
:
gt_label_list
,
"token_type_ids"
:
token_type_ids_list
,
"bbox"
:
bbox_list
,
"attention_mask"
:
[
1
]
*
len
(
input_ids_list
),
# "words_list": words_list,
}
encoded_inputs
=
self
.
pad_sentences
(
encoded_inputs
,
max_seq_len
=
self
.
max_seq_len
,
return_attention_mask
=
self
.
return_attention_mask
)
encoded_inputs
=
self
.
truncate_inputs
(
encoded_inputs
)
if
self
.
contains_re
:
relations
=
self
.
_relations
(
entities
,
relations
,
id2label
,
empty_entity
,
entity_id_to_index_map
)
encoded_inputs
[
'relations'
]
=
relations
encoded_inputs
[
'entities'
]
=
entities
return
encoded_inputs
    def _chunk_ser(self, encoded_inputs):
        encoded_inputs_all = []
        seq_len = len(encoded_inputs['input_ids'])
        chunk_size = 512
        for chunk_id, index in enumerate(range(0, seq_len, chunk_size)):
            chunk_beg = index
            chunk_end = min(index + chunk_size, seq_len)
            encoded_inputs_example = {}
            for key in encoded_inputs:
                encoded_inputs_example[key] = encoded_inputs[key][chunk_beg:
                                                                  chunk_end]
            encoded_inputs_all.append(encoded_inputs_example)
        return encoded_inputs_all

    def _chunk_re(self, encoded_inputs):
        # prepare data
        entities = encoded_inputs.pop('entities')
        relations = encoded_inputs.pop('relations')
        encoded_inputs_all = []
        chunk_size = 512
        for chunk_id, index in enumerate(
                range(0, len(encoded_inputs["input_ids"]), chunk_size)):
            item = {}
            for k in encoded_inputs:
                item[k] = encoded_inputs[k][index:index + chunk_size]

            # select entity in current chunk
            entities_in_this_span = []
            global_to_local_map = {}
            for entity_id, entity in enumerate(entities):
                if (index <= entity["start"] < index + chunk_size and
                        index <= entity["end"] < index + chunk_size):
                    entity["start"] = entity["start"] - index
                    entity["end"] = entity["end"] - index
                    global_to_local_map[entity_id] = len(entities_in_this_span)
                    entities_in_this_span.append(entity)

            # select relations in current chunk
            relations_in_this_span = []
            for relation in relations:
                if (index <= relation["start_index"] < index + chunk_size and
                        index <= relation["end_index"] < index + chunk_size):
                    relations_in_this_span.append({
                        "head": global_to_local_map[relation["head"]],
                        "tail": global_to_local_map[relation["tail"]],
                        "start_index": relation["start_index"] - index,
                        "end_index": relation["end_index"] - index,
                    })
            item.update({
                "entities": reformat(entities_in_this_span),
                "relations": reformat(relations_in_this_span),
            })
            item['entities']['label'] = [
                self.entities_labels[x] for x in item['entities']['label']
            ]
            encoded_inputs_all.append(item)
        return encoded_inputs_all
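Both chunkers walk the encoded sequence in fixed 512-token windows; _chunk_re additionally remaps entity and relation offsets into each window and drops any that straddle a boundary. A minimal sketch of the windowing itself, with the chunk size shrunk to 4 purely for readability:

def chunk(encoded_inputs, chunk_size=4):      # 512 in the real code
    seq_len = len(encoded_inputs["input_ids"])
    out = []
    for beg in range(0, seq_len, chunk_size):
        end = min(beg + chunk_size, seq_len)
        out.append({k: v[beg:end] for k, v in encoded_inputs.items()})
    return out

chunks = chunk({"input_ids": list(range(10)), "labels": [0] * 10})
print([len(c["input_ids"]) for c in chunks])  # [4, 4, 2]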
    def _relations(self, entities, relations, id2label, empty_entity,
                   entity_id_to_index_map):
        """
        build relations
        """
        relations = list(set(relations))
        relations = [
            rel for rel in relations
            if rel[0] not in empty_entity and rel[1] not in empty_entity
        ]
        kv_relations = []
        for rel in relations:
            pair = [id2label[rel[0]], id2label[rel[1]]]
            if pair == ["question", "answer"]:
                kv_relations.append({
                    "head": entity_id_to_index_map[rel[0]],
                    "tail": entity_id_to_index_map[rel[1]]
                })
            elif pair == ["answer", "question"]:
                kv_relations.append({
                    "head": entity_id_to_index_map[rel[1]],
                    "tail": entity_id_to_index_map[rel[0]]
                })
            else:
                continue

        relations = sorted(
            [{
                "head": rel["head"],
                "tail": rel["tail"],
                "start_index": get_relation_span(rel, entities)[0],
                "end_index": get_relation_span(rel, entities)[1],
            } for rel in kv_relations],
            key=lambda x: x["head"], )
        return relations
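_relations normalizes every surviving link so that it points from a question entity to an answer entity, whichever order the annotation stored; a toy run with made-up entity ids:

id2label = {10: "question", 11: "answer"}     # hypothetical ids
entity_id_to_index_map = {10: 0, 11: 1}
for rel in [(10, 11), (11, 10)]:              # both orientations of one link
    pair = [id2label[rel[0]], id2label[rel[1]]]
    if pair == ["question", "answer"]:
        print({"head": entity_id_to_index_map[rel[0]],
               "tail": entity_id_to_index_map[rel[1]]})
    elif pair == ["answer", "question"]:
        print({"head": entity_id_to_index_map[rel[1]],
               "tail": entity_id_to_index_map[rel[0]]})
# both orientations print {'head': 0, 'tail': 1}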
    def load_img(self, image_path):
        # read img
        img = cv2.imread(image_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        resize_h, resize_w = self.img_size
        im_shape = img.shape[0:2]
        im_scale_y = resize_h / im_shape[0]
        im_scale_x = resize_w / im_shape[1]
        img_new = cv2.resize(
            img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=2)
        mean = np.array([0.485, 0.456, 0.406])[np.newaxis, np.newaxis, :]
        std = np.array([0.229, 0.224, 0.225])[np.newaxis, np.newaxis, :]
        img_new = img_new / 255.0
        img_new -= mean
        img_new /= std
        img = img_new.transpose((2, 0, 1))
        return img
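load_img applies the usual ImageNet normalization, (x / 255 - mean) / std per channel, then switches to channels-first layout. A self-contained check of that arithmetic on a random stand-in image (no cv2 required):

import numpy as np

dummy = np.random.rand(64, 128, 3) * 255              # HWC stand-in image
mean = np.array([0.485, 0.456, 0.406])[np.newaxis, np.newaxis, :]
std = np.array([0.229, 0.224, 0.225])[np.newaxis, np.newaxis, :]
out = ((dummy / 255.0 - mean) / std).transpose((2, 0, 1))
print(out.shape)  # (3, 64, 128): channels-first, as the model expects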
    def __getitem__(self, idx):
        if self.load_mode == "all":
            data = copy.deepcopy(self.encoded_inputs_all[idx])
        else:
            data = self._parse_label_file(self.all_lines[idx])[0]

        image_path = data.pop('image_path')
        data["image"] = self.load_img(image_path)

        return_data = {}
        for k, v in data.items():
            if k in self.return_keys:
                if self.return_keys[k]['type'] == 'np':
                    v = np.array(v, dtype=self.return_keys[k]['dtype'])
                return_data[k] = v
        return return_data

    def __len__(self, ):
        if self.load_mode == "all":
            return len(self.encoded_inputs_all)
        else:
            return len(self.all_lines)
def get_relation_span(rel, entities):
    bound = []
    for entity_index in [rel["head"], rel["tail"]]:
        bound.append(entities[entity_index]["start"])
        bound.append(entities[entity_index]["end"])
    return min(bound), max(bound)


def reformat(data):
    new_data = {}
    for item in data:
        for k, v in item.items():
            if k not in new_data:
                new_data[k] = []
            new_data[k].append(v)
    return new_data
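reformat is a plain list-of-dicts to dict-of-lists transpose, the layout _chunk_re stores per chunk:

print(reformat([{"start": 0, "end": 3}, {"start": 5, "end": 9}]))
# {'start': [0, 5], 'end': [3, 9]}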
requirements.txt
@@ -13,4 +13,3 @@ lxml
 premailer
 openpyxl
 fasttext==0.9.1
-paddlenlp>=2.2.1
tools/eval.py
@@ -61,7 +61,8 @@ def main():
     else:
         model_type = None
-    best_model_dict = load_model(config, model)
+    best_model_dict = load_model(
+        config, model, model_type=config['Architecture']["model_type"])
     if len(best_model_dict):
         logger.info('metric in ckpt ***************')
         for k, v in best_model_dict.items():
...
tools/infer_vqa_token_ser.py (new file, mode 100755)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import json
import paddle

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.visual import draw_ser_results
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps
import tools.program as program
def to_tensor(data):
    # Collate one transformed sample: every numeric field (ndarray, tensor
    # or plain number) is wrapped into a batch-of-one paddle.Tensor.
    import numbers
    from collections import defaultdict
    data_dict = defaultdict(list)
    to_tensor_idxs = []
    for idx, v in enumerate(data):
        if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
            if idx not in to_tensor_idxs:
                to_tensor_idxs.append(idx)
        data_dict[idx].append(v)
    for idx in to_tensor_idxs:
        data_dict[idx] = paddle.to_tensor(data_dict[idx])
    return list(data_dict.values())
class SerPredictor(object):
    def __init__(self, config):
        global_config = config['Global']

        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)

        # build model
        self.model = build_model(config['Architecture'])

        load_model(
            config, self.model,
            model_type=config['Architecture']["model_type"])

        from paddleocr import PaddleOCR

        self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)

        # create data ops
        transforms = []
        for op in config['Eval']['dataset']['transforms']:
            op_name = list(op)[0]
            if 'Label' in op_name:
                op[op_name]['ocr_engine'] = self.ocr_engine
            elif op_name == 'KeepKeys':
                op[op_name]['keep_keys'] = [
                    'input_ids', 'labels', 'bbox', 'image', 'attention_mask',
                    'token_type_ids', 'segment_offset_id', 'ocr_info',
                    'entities'
                ]
            transforms.append(op)
        global_config['infer_mode'] = True
        self.ops = create_operators(config['Eval']['dataset']['transforms'],
                                    global_config)
        self.model.eval()

    def __call__(self, img_path):
        with open(img_path, 'rb') as f:
            img = f.read()
        data = {'image': img}
        batch = transform(data, self.ops)
        batch = to_tensor(batch)
        preds = self.model(batch)
        post_result = self.post_process_class(
            preds,
            attention_masks=batch[4],
            segment_offset_ids=batch[6],
            ocr_infos=batch[7])
        return post_result, batch
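Note that the positional indices in __call__ follow the keep_keys order set above: batch[4] is attention_mask, batch[6] is segment_offset_id and batch[7] is ocr_info (indices 0-3 are input_ids, labels, bbox and image; 5 is token_type_ids; 8 is entities).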
if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess()
    os.makedirs(config['Global']['save_res_path'], exist_ok=True)

    ser_engine = SerPredictor(config)

    infer_imgs = get_image_file_list(config['Global']['infer_img'])
    with open(
            os.path.join(config['Global']['save_res_path'],
                         "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, img_path in enumerate(infer_imgs):
            save_img_path = os.path.join(
                config['Global']['save_res_path'],
                os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
            logger.info("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))

            result, _ = ser_engine(img_path)
            result = result[0]
            fout.write(img_path + "\t" + json.dumps(
                {
                    "ocr_info": result,
                }, ensure_ascii=False) + "\n")
            img_res = draw_ser_results(img_path, result)
            cv2.imwrite(save_img_path, img_res)
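Assuming the standard -c/-o options inherited from ArgsParser (the config and image paths below are placeholders, not files from this commit), a run looks like:

python3 tools/infer_vqa_token_ser.py -c path/to/ser_config.yml -o Global.infer_img=path/to/images Global.save_res_path=output/ser

Each image then contributes one line to infer_results.txt and a visualization saved as <image name>_ser.jpg.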
tools/infer_vqa_token_ser_re.py (new file, mode 100755)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import json
import paddle
import paddle.distributed as dist

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.visual import draw_re_results
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
from tools.program import ArgsParser, load_config, merge_config, check_gpu
from tools.infer_vqa_token_ser import SerPredictor


class ReArgsParser(ArgsParser):
    def __init__(self):
        super(ReArgsParser, self).__init__()
        self.add_argument(
            "-c_ser", "--config_ser", help="ser configuration file to use")
        self.add_argument(
            "-o_ser",
            "--opt_ser",
            nargs='+',
            help="set ser configuration options ")

    def parse_args(self, argv=None):
        args = super(ReArgsParser, self).parse_args(argv)
        assert args.config_ser is not None, \
            "Please specify --config_ser=ser_configure_file_path."
        args.opt_ser = self._parse_opt(args.opt_ser)
        return args
def make_input(ser_inputs, ser_results):
    entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}

    entities = ser_inputs[8][0]
    ser_results = ser_results[0]
    assert len(entities) == len(ser_results)

    # entities
    start = []
    end = []
    label = []
    entity_idx_dict = {}
    for i, (res, entity) in enumerate(zip(ser_results, entities)):
        if res['pred'] == 'O':
            continue
        entity_idx_dict[len(start)] = i
        start.append(entity['start'])
        end.append(entity['end'])
        label.append(entities_labels[res['pred']])
    entities = dict(start=start, end=end, label=label)

    # relations
    head = []
    tail = []
    for i in range(len(entities["label"])):
        for j in range(len(entities["label"])):
            if entities["label"][i] == 1 and entities["label"][j] == 2:
                head.append(i)
                tail.append(j)

    relations = dict(head=head, tail=tail)

    batch_size = ser_inputs[0].shape[0]
    entities_batch = []
    relations_batch = []
    entity_idx_dict_batch = []
    for b in range(batch_size):
        entities_batch.append(entities)
        relations_batch.append(relations)
        entity_idx_dict_batch.append(entity_idx_dict)

    ser_inputs[8] = entities_batch
    ser_inputs.append(relations_batch)
    # drop ocr_info (7), segment_offset_id (6) and labels (1) from the SER batch
    ser_inputs.pop(7)
    ser_inputs.pop(6)
    ser_inputs.pop(1)
    return ser_inputs, entity_idx_dict_batch
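make_input enumerates every (QUESTION, ANSWER) entity pair as a relation candidate; with entities_labels as defined above (QUESTION=1, ANSWER=2), two questions and one answer yield two head/tail candidates:

labels = [1, 2, 1]                        # QUESTION, ANSWER, QUESTION
head, tail = [], []
for i in range(len(labels)):
    for j in range(len(labels)):
        if labels[i] == 1 and labels[j] == 2:
            head.append(i)
            tail.append(j)
print(dict(head=head, tail=tail))         # {'head': [0, 2], 'tail': [1, 1]}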
class SerRePredictor(object):
    def __init__(self, config, ser_config):
        self.ser_engine = SerPredictor(ser_config)

        # init re model
        global_config = config['Global']

        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)

        # build model
        self.model = build_model(config['Architecture'])

        load_model(
            config, self.model,
            model_type=config['Architecture']["model_type"])

        self.model.eval()

    def __call__(self, img_path):
        ser_results, ser_inputs = self.ser_engine(img_path)
        paddle.save(ser_inputs, 'ser_inputs.npy')
        paddle.save(ser_results, 'ser_results.npy')
        re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
        preds = self.model(re_input)
        post_result = self.post_process_class(
            preds,
            ser_results=ser_results,
            entity_idx_dict_batch=entity_idx_dict_batch)
        return post_result
def preprocess():
    FLAGS = ReArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)

    ser_config = load_config(FLAGS.config_ser)
    ser_config = merge_config(ser_config, FLAGS.opt_ser)

    logger = get_logger(name='root')

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    check_gpu(use_gpu)

    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)

    logger.info('{} re config {}'.format('*' * 10, '*' * 10))
    print_dict(config, logger)
    logger.info('\n')
    logger.info('{} ser config {}'.format('*' * 10, '*' * 10))
    print_dict(ser_config, logger)
    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, ser_config, device, logger
if __name__ == '__main__':
    config, ser_config, device, logger = preprocess()
    os.makedirs(config['Global']['save_res_path'], exist_ok=True)

    ser_re_engine = SerRePredictor(config, ser_config)

    infer_imgs = get_image_file_list(config['Global']['infer_img'])
    with open(
            os.path.join(config['Global']['save_res_path'],
                         "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, img_path in enumerate(infer_imgs):
            save_img_path = os.path.join(
                config['Global']['save_res_path'],
                os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
            logger.info("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))

            result = ser_re_engine(img_path)
            result = result[0]
            fout.write(img_path + "\t" + json.dumps(
                {
                    "ser_result": result,
                }, ensure_ascii=False) + "\n")
            img_res = draw_re_results(img_path, result)
            cv2.imwrite(save_img_path, img_res)
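Assuming the -c/-c_ser options defined by ReArgsParser above (config paths are placeholders):

python3 tools/infer_vqa_token_ser_re.py -c path/to/re_config.yml -c_ser path/to/ser_config.yml -o Global.infer_img=path/to/images

The SER engine runs first; make_input then converts its outputs into RE model inputs, so one command produces both predictions.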
tools/program.py
@@ -69,24 +69,6 @@ class ArgsParser(ArgumentParser):
         return config


-class AttrDict(dict):
-    """Single level attribute dict, NOT recursive"""
-
-    def __init__(self, **kwargs):
-        super(AttrDict, self).__init__()
-        super(AttrDict, self).update(kwargs)
-
-    def __getattr__(self, key):
-        if key in self:
-            return self[key]
-        raise AttributeError("object has no attribute '{}'".format(key))
-
-
-global_config = AttrDict()
-
-default_config = {'Global': {'debug': False, }}
-
 def load_config(file_path):
     """
     Load config from yml/yaml file.

@@ -94,38 +76,38 @@ def load_config(file_path):
         file_path (str): Path of the config file to be loaded.
     Returns: global config
     """
-    merge_config(default_config)
     _, ext = os.path.splitext(file_path)
     assert ext in ['.yml', '.yaml'], "only support yaml files for now"
-    merge_config(yaml.load(open(file_path, 'rb'), Loader=yaml.Loader))
-    return global_config
+    config = yaml.load(open(file_path, 'rb'), Loader=yaml.Loader)
+    return config


-def merge_config(config):
+def merge_config(config, opts):
     """
     Merge config into global config.
     Args:
         config (dict): Config to be merged.
     Returns: global config
     """
-    for key, value in config.items():
+    for key, value in opts.items():
         if "." not in key:
-            if isinstance(value, dict) and key in global_config:
-                global_config[key].update(value)
+            if isinstance(value, dict) and key in config:
+                config[key].update(value)
             else:
-                global_config[key] = value
+                config[key] = value
         else:
             sub_keys = key.split('.')
             assert (
-                sub_keys[0] in global_config
+                sub_keys[0] in config
             ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
-                global_config.keys(), sub_keys[0])
-            cur = global_config[sub_keys[0]]
+                config.keys(), sub_keys[0])
+            cur = config[sub_keys[0]]
             for idx, sub_key in enumerate(sub_keys[1:]):
                 if idx == len(sub_keys) - 2:
                     cur[sub_key] = value
                 else:
                     cur = cur[sub_key]
+    return config


 def check_gpu(use_gpu):
...
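After this change merge_config no longer mutates a module-level global_config: it takes the loaded dict plus an options dict and returns the merged result, with dotted keys drilling into nested sections. A toy run, independent of any real config file:

config = {'Global': {'use_gpu': True, 'epoch_num': 10}}
config = merge_config(config, {'Global.use_gpu': False, 'Optimizer': {'lr': 0.001}})
print(config)
# {'Global': {'use_gpu': False, 'epoch_num': 10}, 'Optimizer': {'lr': 0.001}}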
@@ -204,20 +186,24 @@ def train(config,
     model_type = None
     algorithm = config['Architecture']['algorithm']

-    if 'start_epoch' in best_model_dict:
-        start_epoch = best_model_dict['start_epoch']
-    else:
-        start_epoch = 1
+    start_epoch = best_model_dict[
+        'start_epoch'] if 'start_epoch' in best_model_dict else 1

+    train_reader_cost = 0.0
+    train_run_cost = 0.0
+    total_samples = 0
+    reader_start = time.time()
+    max_iter = len(train_dataloader) - 1 if platform.system(
+    ) == "Windows" else len(train_dataloader)
     for epoch in range(start_epoch, epoch_num + 1):
-        train_dataloader = build_dataloader(
-            config, 'Train', device, logger, seed=epoch)
-        train_reader_cost = 0.0
-        train_run_cost = 0.0
-        total_samples = 0
-        reader_start = time.time()
-        max_iter = len(train_dataloader) - 1 if platform.system(
-        ) == "Windows" else len(train_dataloader)
+        if train_dataloader.dataset.need_reset:
+            train_dataloader = build_dataloader(
+                config, 'Train', device, logger, seed=epoch)
+            max_iter = len(train_dataloader) - 1 if platform.system(
+            ) == "Windows" else len(train_dataloader)
         for idx, batch in enumerate(train_dataloader):
             profiler.add_profiler_step(profiler_options)
             train_reader_cost += time.time() - reader_start
...
@@ -239,10 +225,11 @@ def train(config,
             else:
                 if model_type == 'table' or extra_input:
                     preds = model(images, data=batch[1:])
-                elif model_type == "kie":
+                elif model_type in ["kie", 'vqa']:
                     preds = model(batch)
                 else:
                     preds = model(images)
             loss = loss_class(preds, batch)
             avg_loss = loss['loss']
...
@@ -256,6 +243,7 @@ def train(config,
             optimizer.clear_grad()

             train_run_cost += time.time() - train_start
+            global_step += 1
             total_samples += len(images)

             if not isinstance(lr_scheduler, float):
...
@@ -285,12 +273,13 @@ def train(config,
                     (global_step > 0 and global_step % print_batch_step == 0) or
                     (idx >= len(train_dataloader) - 1)):
                logs = train_stats.log()
-                strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
+                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ips: {:.5f}'.format(
                     epoch, epoch_num, global_step, logs, train_reader_cost /
                     print_batch_step, (train_reader_cost + train_run_cost) /
-                    print_batch_step, total_samples,
+                    print_batch_step, total_samples / print_batch_step,
                     total_samples / (train_reader_cost + train_run_cost))
                 logger.info(strs)
                 train_reader_cost = 0.0
                 train_run_cost = 0.0
                 total_samples = 0
...
@@ -330,6 +319,7 @@ def train(config,
                     optimizer,
                     save_model_dir,
                     logger,
+                    config,
                     is_best=True,
                     prefix='best_accuracy',
                     best_model_dict=best_model_dict,
...
@@ -344,8 +334,7 @@ def train(config,
                 vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
                                       best_model_dict[main_indicator],
                                       global_step)
-            global_step += 1
-            optimizer.clear_grad()
             reader_start = time.time()
         if dist.get_rank() == 0:
             save_model(
...
@@ -353,6 +342,7 @@ def train(config,
                 optimizer,
                 save_model_dir,
                 logger,
+                config,
                 is_best=False,
                 prefix='latest',
                 best_model_dict=best_model_dict,
...
@@ -364,6 +354,7 @@ def train(config,
                 optimizer,
                 save_model_dir,
                 logger,
+                config,
                 is_best=False,
                 prefix='iter_epoch_{}'.format(epoch),
                 best_model_dict=best_model_dict,
...
@@ -401,19 +392,28 @@ def eval(model,
             start = time.time()
             if model_type == 'table' or extra_input:
                 preds = model(images, data=batch[1:])
-            elif model_type == "kie":
+            elif model_type in ["kie", 'vqa']:
                 preds = model(batch)
             else:
                 preds = model(images)
-            batch = [item.numpy() for item in batch]
+            batch_numpy = []
+            for item in batch:
+                if isinstance(item, paddle.Tensor):
+                    batch_numpy.append(item.numpy())
+                else:
+                    batch_numpy.append(item)
             # Obtain usable results from post-processing methods
             total_time += time.time() - start
             # Evaluate the results of the current batch
             if model_type in ['table', 'kie']:
-                eval_class(preds, batch)
+                eval_class(preds, batch_numpy)
+            elif model_type in ['vqa']:
+                post_result = post_process_class(preds, batch_numpy)
+                eval_class(post_result, batch_numpy)
             else:
-                post_result = post_process_class(preds, batch[1])
-                eval_class(post_result, batch)
+                post_result = post_process_class(preds, batch_numpy[1])
+                eval_class(post_result, batch_numpy)
             pbar.update(1)
             total_frame += len(images)
...
@@ -479,9 +479,9 @@ def preprocess(is_train=False):
     FLAGS = ArgsParser().parse_args()
     profiler_options = FLAGS.profiler_options
     config = load_config(FLAGS.config)
-    merge_config(FLAGS.opt)
+    config = merge_config(config, FLAGS.opt)
     profile_dic = {"profiler_options": FLAGS.profiler_options}
-    merge_config(profile_dic)
+    config = merge_config(config, profile_dic)
     if is_train:
         # save_config
...
@@ -503,13 +503,8 @@ def preprocess(is_train=False):
     assert alg in [
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
         'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
-        'SEED', 'SDMGR'
+        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM'
     ]
-    windows_not_support_list = ['PSE']
-    if platform.system() == "Windows" and alg in windows_not_support_list:
-        logger.warning('{} is not support in Windows now'.format(
-            windows_not_support_list))
-        sys.exit()

     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
     device = paddle.set_device(device)
...
tools/train.py
@@ -97,7 +97,8 @@ def main(config, device, logger, vdl_writer):
     # build metric
     eval_class = build_metric(config['Metric'])
     # load pretrain model
-    pre_best_model_dict = load_model(config, model, optimizer)
+    pre_best_model_dict = load_model(config, model, optimizer,
+                                     config['Architecture']["model_type"])
     logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
     if valid_dataloader is not None:
         logger.info('valid dataloader has {} iters'.format(
...