Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
a323fce6
Commit
a323fce6
authored
Jan 05, 2022
by
WenmuZhou
Browse files
vqa code integrated into ppocr training system
parent
1ded2ac4
Changes
54
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
394 additions
and
1870 deletions
+394
-1870
ppstructure/vqa/infer_ser_e2e.py
ppstructure/vqa/infer_ser_e2e.py
+0
-156
ppstructure/vqa/infer_ser_re_e2e.py
ppstructure/vqa/infer_ser_re_e2e.py
+0
-135
ppstructure/vqa/metric.py
ppstructure/vqa/metric.py
+0
-175
ppstructure/vqa/requirements.txt
ppstructure/vqa/requirements.txt
+2
-1
ppstructure/vqa/train_re.py
ppstructure/vqa/train_re.py
+0
-229
ppstructure/vqa/train_ser.py
ppstructure/vqa/train_ser.py
+0
-248
ppstructure/vqa/vqa_utils.py
ppstructure/vqa/vqa_utils.py
+0
-400
ppstructure/vqa/xfun.py
ppstructure/vqa/xfun.py
+0
-464
requirements.txt
requirements.txt
+0
-1
tools/eval.py
tools/eval.py
+2
-1
tools/infer_vqa_token_ser.py
tools/infer_vqa_token_ser.py
+135
-0
tools/infer_vqa_token_ser_re.py
tools/infer_vqa_token_ser_re.py
+199
-0
tools/program.py
tools/program.py
+54
-59
tools/train.py
tools/train.py
+2
-1
No files found.
ppstructure/vqa/infer_ser_e2e.py
deleted
100644 → 0
View file @
1ded2ac4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
import
json
import
cv2
import
numpy
as
np
from
copy
import
deepcopy
from
PIL
import
Image
import
paddle
from
paddlenlp.transformers
import
LayoutXLMModel
,
LayoutXLMTokenizer
,
LayoutXLMForTokenClassification
from
paddlenlp.transformers
import
LayoutLMModel
,
LayoutLMTokenizer
,
LayoutLMForTokenClassification
# relative reference
from
vqa_utils
import
parse_args
,
get_image_file_list
,
draw_ser_results
,
get_bio_label_maps
from
vqa_utils
import
pad_sentences
,
split_page
,
preprocess
,
postprocess
,
merge_preds_list_with_ocr_info
# Map the --ser_model_type CLI value to the (tokenizer, backbone,
# token-classification head) classes used to build the SER model.
MODELS = {
    'LayoutXLM':
    (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
    'LayoutLM':
    (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
def trans_poly_to_bbox(poly):
    """Collapse a polygon (sequence of [x, y] points) into its axis-aligned
    bounding box ``[x_min, y_min, x_max, y_max]``."""
    xs = [point[0] for point in poly]
    ys = [point[1] for point in poly]
    return [np.min(xs), np.min(ys), np.max(xs), np.max(ys)]
def parse_ocr_info_for_ser(ocr_result):
    """Convert raw PaddleOCR output into the per-line dicts the SER pipeline
    expects: each entry carries the recognized text, an axis-aligned bbox and
    the original detection polygon."""
    return [{
        "text": line[1][0],
        "bbox": trans_poly_to_bbox(line[0]),
        "poly": line[0],
    } for line in ocr_result]
class SerPredictor(object):
    """End-to-end semantic entity recognition (SER) predictor.

    Runs PaddleOCR to detect/recognize text lines in an image, feeds the
    lines through a LayoutLM/LayoutXLM token-classification model and merges
    the token predictions back onto the OCR line info.
    """

    def __init__(self, args):
        # args: parsed command-line namespace (see vqa_utils.parse_args)
        self.args = args
        self.max_seq_length = args.max_seq_length

        # init ser token and model
        tokenizer_class, base_model_class, model_class = MODELS[
            args.ser_model_type]
        self.tokenizer = tokenizer_class.from_pretrained(
            args.model_name_or_path)
        self.model = model_class.from_pretrained(args.model_name_or_path)
        self.model.eval()

        # init ocr_engine — imported lazily so paddleocr is only required
        # when inference is actually run
        from paddleocr import PaddleOCR

        self.ocr_engine = PaddleOCR(
            rec_model_dir=args.rec_model_dir,
            det_model_dir=args.det_model_dir,
            use_angle_cls=False,
            show_log=False)

        # init dict: id2label_map decodes model outputs;
        # label2id_map_for_draw folds every "I-x" label onto the id of the
        # matching "B-x" label so drawing treats begin/inside tokens of one
        # entity identically.
        label2id_map, self.id2label_map = get_bio_label_maps(
            args.label_map_path)
        self.label2id_map_for_draw = dict()
        for key in label2id_map:
            if key.startswith("I-"):
                self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
            else:
                self.label2id_map_for_draw[key] = label2id_map[key]

    def __call__(self, img):
        """Run OCR + SER on one image.

        Returns:
            tuple: (ocr_info with per-line predictions merged in,
                    the preprocessed model inputs dict).
        """
        ocr_result = self.ocr_engine.ocr(img, cls=False)

        ocr_info = parse_ocr_info_for_ser(ocr_result)

        inputs = preprocess(
            tokenizer=self.tokenizer,
            ori_img=img,
            ocr_info=ocr_info,
            max_seq_len=self.max_seq_length)

        if self.args.ser_model_type == 'LayoutLM':
            # LayoutLM takes no image feature input
            preds = self.model(
                input_ids=inputs["input_ids"],
                bbox=inputs["bbox"],
                token_type_ids=inputs["token_type_ids"],
                attention_mask=inputs["attention_mask"])
        elif self.args.ser_model_type == 'LayoutXLM':
            preds = self.model(
                input_ids=inputs["input_ids"],
                bbox=inputs["bbox"],
                image=inputs["image"],
                token_type_ids=inputs["token_type_ids"],
                attention_mask=inputs["attention_mask"])
            # LayoutXLM returns a tuple; logits are the first element
            preds = preds[0]

        preds = postprocess(inputs["attention_mask"], preds,
                            self.id2label_map)
        ocr_info = merge_preds_list_with_ocr_info(
            ocr_info, inputs["segment_offset_id"], preds,
            self.label2id_map_for_draw)
        return ocr_info, inputs
if __name__ == "__main__":
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # get infer img list
    infer_imgs = get_image_file_list(args.infer_imgs)

    # loop for infer: one SER pass per image, results written both as a
    # tab-separated text file and as annotated images
    ser_engine = SerPredictor(args)
    with open(
            os.path.join(args.output_dir, "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, img_path in enumerate(infer_imgs):
            save_img_path = os.path.join(
                args.output_dir,
                os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
            print("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))

            img = cv2.imread(img_path)

            result, _ = ser_engine(img)
            # NOTE(review): "ser_resule" looks like a typo for "ser_result",
            # but it is part of the on-disk output format — confirm consumers
            # before renaming.
            fout.write(img_path + "\t" + json.dumps(
                {
                    "ser_resule": result,
                }, ensure_ascii=False) + "\n")

            img_res = draw_ser_results(img, result)
            cv2.imwrite(save_img_path, img_res)
ppstructure/vqa/infer_ser_re_e2e.py
deleted
100644 → 0
View file @
1ded2ac4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
json
import
cv2
import
numpy
as
np
from
copy
import
deepcopy
from
PIL
import
Image
import
paddle
from
paddlenlp.transformers
import
LayoutXLMModel
,
LayoutXLMTokenizer
,
LayoutXLMForRelationExtraction
# relative reference
from
vqa_utils
import
parse_args
,
get_image_file_list
,
draw_re_results
from
infer_ser_e2e
import
SerPredictor
def make_input(ser_input, ser_result, max_seq_len=512):
    """Build the relation-extraction model input from SER outputs.

    Mutates *ser_input* in place: replaces ``'entities'`` with a per-batch
    entity dict, adds ``'relations'`` holding every QUESTION->ANSWER candidate
    pair, and drops ``'segment_offset_id'``.

    Returns:
        tuple: (the mutated ser_input, mapping from compacted entity index
        back to the original ser_result index).
    """
    entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}

    raw_entities = ser_input['entities'][0]
    assert len(raw_entities) == len(ser_result)

    # Keep only entities whose prediction is not "O", remembering where each
    # kept entity came from in the original list.
    start, end, label = [], [], []
    entity_idx_dict = {}
    for orig_idx, (pred_item, entity) in enumerate(
            zip(ser_result, raw_entities)):
        if pred_item['pred'] == 'O':
            continue
        entity_idx_dict[len(start)] = orig_idx
        start.append(entity['start'])
        end.append(entity['end'])
        label.append(entities_labels[pred_item['pred']])
    entities = dict(start=start, end=end, label=label)

    # Candidate relations: every QUESTION (label 1) paired with every ANSWER
    # (label 2), in question-major order.
    question_ids = [i for i, lab in enumerate(entities["label"]) if lab == 1]
    answer_ids = [j for j, lab in enumerate(entities["label"]) if lab == 2]
    head, tail = [], []
    for q in question_ids:
        for a in answer_ids:
            head.append(q)
            tail.append(a)
    relations = dict(head=head, tail=tail)

    # Replicate the single entity/relation dicts across the batch dimension.
    batch_size = ser_input["input_ids"].shape[0]
    ser_input['entities'] = [entities for _ in range(batch_size)]
    ser_input['relations'] = [relations for _ in range(batch_size)]

    ser_input.pop('segment_offset_id')
    return ser_input, entity_idx_dict
class SerReSystem(object):
    """End-to-end SER + relation extraction (RE) pipeline.

    First runs SerPredictor to label entities, then a LayoutXLM RE model to
    link QUESTION entities to their ANSWER entities.
    """

    def __init__(self, args):
        self.ser_engine = SerPredictor(args)
        self.tokenizer = LayoutXLMTokenizer.from_pretrained(
            args.re_model_name_or_path)
        self.model = LayoutXLMForRelationExtraction.from_pretrained(
            args.re_model_name_or_path)
        self.model.eval()

    def __call__(self, img):
        """Return a list of (question_ocr_info, answer_ocr_info) pairs."""
        ser_result, ser_inputs = self.ser_engine(img)
        re_input, entity_idx_dict = make_input(ser_inputs, ser_result)

        re_result = self.model(**re_input)

        pred_relations = re_result['pred_relations'][0]
        # Translate predicted relations back to OCR line info
        # (original comment: "进行 relations 到 ocr信息的转换").
        result = []
        used_tail_id = []
        for relation in pred_relations:
            # each ANSWER (tail) is linked at most once
            if relation['tail_id'] in used_tail_id:
                continue
            used_tail_id.append(relation['tail_id'])
            ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
            ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
            result.append((ocr_info_head, ocr_info_tail))
        return result
if __name__ == "__main__":
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # get infer img list
    infer_imgs = get_image_file_list(args.infer_imgs)

    # loop for infer: one SER+RE pass per image, results written both as a
    # tab-separated text file and as annotated images
    ser_re_engine = SerReSystem(args)
    with open(
            os.path.join(args.output_dir, "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, img_path in enumerate(infer_imgs):
            save_img_path = os.path.join(
                args.output_dir,
                os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg")
            print("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))

            img = cv2.imread(img_path)

            result = ser_re_engine(img)
            fout.write(img_path + "\t" + json.dumps(
                {
                    "result": result,
                }, ensure_ascii=False) + "\n")

            img_res = draw_re_results(img, result)
            cv2.imwrite(save_img_path, img_res)
ppstructure/vqa/metric.py
deleted
100644 → 0
View file @
1ded2ac4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
re
import
numpy
as
np
import
logging
logger = logging.getLogger(__name__)

# Checkpoint directories are named "checkpoint-<step>".
PREFIX_CHECKPOINT_DIR = "checkpoint"
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")


def get_last_checkpoint(folder):
    """Return the path of the highest-numbered ``checkpoint-<n>`` subdirectory
    of *folder*, or None when no such subdirectory exists."""
    numbered = []
    for entry in os.listdir(folder):
        match = _re_checkpoint.search(entry)
        if match is not None and os.path.isdir(os.path.join(folder, entry)):
            numbered.append((int(match.groups()[0]), entry))
    if not numbered:
        return None
    return os.path.join(folder, max(numbered)[1])
def re_score(pred_relations, gt_relations, mode="strict"):
    """Evaluate RE predictions.

    Args:
        pred_relations (list): list of list of predicted relations (several
            relations in each sentence)
        gt_relations (list): list of list of ground truth relations
            rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
                    "tail": (start_idx (inclusive), end_idx (exclusive)),
                    "head_type": ent_type,
                    "tail_type": ent_type,
                    "type": rel_type}
        mode (str): 'strict' (spans AND argument types must match) or
            'boundaries' (spans only)

    Returns:
        dict: per-type tp/fp/fn/p/r/f1, plus micro (p/r/f1) and macro
        (Macro_p/Macro_r/Macro_f1) averages under key "ALL".
    """
    assert mode in ["strict", "boundaries"]

    # Only relation type 1 is scored; type 0 is the "no relation" class.
    relation_types = [v for v in [0, 1] if not v == 0]
    scores = {
        rel: {
            "tp": 0,
            "fp": 0,
            "fn": 0
        }
        for rel in relation_types + ["ALL"]
    }

    # Corpus-level counts (kept for the disabled summary logging).
    n_sents = len(gt_relations)
    n_rels = sum(len([rel for rel in sent]) for sent in gt_relations)
    n_found = sum(len([rel for rel in sent]) for sent in pred_relations)

    def rel_key(rel):
        # The tuple two relations must share to count as a match.
        if mode == "strict":
            return (rel["head"], rel["head_type"], rel["tail"],
                    rel["tail_type"])
        return (rel["head"], rel["tail"])

    # Accumulate TP / FP / FN per relation type, sentence by sentence.
    for pred_sent, gt_sent in zip(pred_relations, gt_relations):
        for rel_type in relation_types:
            pred_rels = {
                rel_key(rel)
                for rel in pred_sent if rel["type"] == rel_type
            }
            gt_rels = {
                rel_key(rel)
                for rel in gt_sent if rel["type"] == rel_type
            }
            bucket = scores[rel_type]
            bucket["tp"] += len(pred_rels & gt_rels)
            bucket["fp"] += len(pred_rels - gt_rels)
            bucket["fn"] += len(gt_rels - pred_rels)

    # Per-type precision / recall / F1 (the "ALL" bucket is recomputed from
    # micro counts below).
    for bucket in scores.values():
        if bucket["tp"]:
            bucket["p"] = bucket["tp"] / (bucket["fp"] + bucket["tp"])
            bucket["r"] = bucket["tp"] / (bucket["fn"] + bucket["tp"])
        else:
            bucket["p"], bucket["r"] = 0, 0
        denom = bucket["p"] + bucket["r"]
        bucket["f1"] = 2 * bucket["p"] * bucket["r"] / denom if denom != 0 else 0

    # Micro-averaged scores over all scored relation types.
    tp = sum(scores[rel_type]["tp"] for rel_type in relation_types)
    fp = sum(scores[rel_type]["fp"] for rel_type in relation_types)
    fn = sum(scores[rel_type]["fn"] for rel_type in relation_types)
    if tp:
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)
    else:
        precision, recall, f1 = 0, 0, 0
    scores["ALL"]["p"] = precision
    scores["ALL"]["r"] = recall
    scores["ALL"]["f1"] = f1
    scores["ALL"]["tp"] = tp
    scores["ALL"]["fp"] = fp
    scores["ALL"]["fn"] = fn

    # Macro-averaged scores.
    scores["ALL"]["Macro_f1"] = np.mean(
        [scores[t]["f1"] for t in relation_types])
    scores["ALL"]["Macro_p"] = np.mean(
        [scores[t]["p"] for t in relation_types])
    scores["ALL"]["Macro_r"] = np.mean(
        [scores[t]["r"] for t in relation_types])

    return scores
ppstructure/vqa/requirements.txt
View file @
a323fce6
sentencepiece
yacs
seqeval
\ No newline at end of file
seqeval
paddlenlp>=2.2.1
\ No newline at end of file
ppstructure/vqa/train_re.py
deleted
100644 → 0
View file @
1ded2ac4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
import
random
import
time
import
numpy
as
np
import
paddle
from
paddlenlp.transformers
import
LayoutXLMTokenizer
,
LayoutXLMModel
,
LayoutXLMForRelationExtraction
from
xfun
import
XFUNDataset
from
vqa_utils
import
parse_args
,
get_bio_label_maps
,
print_arguments
,
set_seed
from
data_collator
import
DataCollator
from
eval_re
import
evaluate
from
ppocr.utils.logging
import
get_logger
def train(args):
    """Train the LayoutXLM relation-extraction model on XFUN-format data.

    Saves the best model (by eval f1) to <output_dir>/best_model and the most
    recent weights to <output_dir>/latest_model.
    """
    logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
    rank = paddle.distributed.get_rank()
    distributed = paddle.distributed.get_world_size() > 1

    print_arguments(args, logger)

    # Added here for reproducibility (even between python 2 and 3)
    set_seed(args.seed)

    label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
    pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index

    # dist mode
    if distributed:
        paddle.distributed.init_parallel_env()

    tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
    if not args.resume:
        # fresh run: wrap a pretrained backbone with an untrained RE head
        model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
        model = LayoutXLMForRelationExtraction(model, dropout=None)
        logger.info('train from scratch')
    else:
        logger.info('resume from {}'.format(args.model_name_or_path))
        model = LayoutXLMForRelationExtraction.from_pretrained(
            args.model_name_or_path)

    # dist mode
    if distributed:
        model = paddle.DataParallel(model)

    train_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.train_data_dir,
        label_path=args.train_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        max_seq_len=args.max_seq_length,
        pad_token_label_id=pad_token_label_id,
        contains_re=True,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    eval_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.eval_data_dir,
        label_path=args.eval_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        max_seq_len=args.max_seq_length,
        pad_token_label_id=pad_token_label_id,
        contains_re=True,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    train_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)

    train_dataloader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=args.num_workers,
        use_shared_memory=True,
        collate_fn=DataCollator())

    eval_dataloader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=args.per_gpu_eval_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        collate_fn=DataCollator())

    t_total = len(train_dataloader) * args.num_train_epochs

    # build linear decay with warmup lr sch
    # NOTE(review): the scheduler is constructed but the optimizer below is
    # given the constant args.learning_rate and lr_scheduler.step() is
    # commented out, so the schedule never takes effect — confirm intent.
    lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
        learning_rate=args.learning_rate,
        decay_steps=t_total,
        end_lr=0.0,
        power=1.0)
    if args.warmup_steps > 0:
        lr_scheduler = paddle.optimizer.lr.LinearWarmup(
            lr_scheduler,
            args.warmup_steps,
            start_lr=0,
            end_lr=args.learning_rate, )

    grad_clip = paddle.nn.ClipGradByNorm(clip_norm=10)
    optimizer = paddle.optimizer.Adam(
        learning_rate=args.learning_rate,
        parameters=model.parameters(),
        epsilon=args.adam_epsilon,
        grad_clip=grad_clip,
        weight_decay=args.weight_decay)

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = {}".format(len(train_dataset)))
    logger.info(" Num Epochs = {}".format(args.num_train_epochs))
    logger.info(" Instantaneous batch size per GPU = {}".format(
        args.per_gpu_train_batch_size))
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = {}".
        format(args.per_gpu_train_batch_size *
               paddle.distributed.get_world_size()))
    logger.info(" Total optimization steps = {}".format(t_total))

    global_step = 0
    model.clear_gradients()
    train_dataloader_len = len(train_dataloader)
    best_metirc = {'f1': 0}
    model.train()

    # rolling cost/throughput counters, reset after every log line
    train_reader_cost = 0.0
    train_run_cost = 0.0
    total_samples = 0
    reader_start = time.time()

    print_step = 1

    for epoch in range(int(args.num_train_epochs)):
        for step, batch in enumerate(train_dataloader):
            train_reader_cost += time.time() - reader_start

            train_start = time.time()
            outputs = model(**batch)
            train_run_cost += time.time() - train_start
            # model outputs are always tuple in ppnlp (see doc)
            loss = outputs['loss']
            loss = loss.mean()

            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            # lr_scheduler.step()  # Update learning rate schedule

            global_step += 1
            total_samples += batch['image'].shape[0]

            if rank == 0 and step % print_step == 0:
                logger.info(
                    "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
                    format(epoch, args.num_train_epochs, step,
                           train_dataloader_len, global_step,
                           np.mean(loss.numpy()),
                           optimizer.get_lr(), train_reader_cost / print_step,
                           (train_reader_cost + train_run_cost) / print_step,
                           total_samples / print_step, total_samples / (
                               train_reader_cost + train_run_cost)))

                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0

            if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training:
                # Log metrics
                # Only evaluate when single GPU otherwise metrics may not average well
                results = evaluate(model, eval_dataloader, logger)
                # keep the checkpoint with the best eval f1 so far
                if results['f1'] >= best_metirc['f1']:
                    best_metirc = results
                    output_dir = os.path.join(args.output_dir, "best_model")
                    os.makedirs(output_dir, exist_ok=True)
                    if distributed:
                        model._layers.save_pretrained(output_dir)
                    else:
                        model.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    paddle.save(args,
                                os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to {}".format(
                        output_dir))
                logger.info("eval results: {}".format(results))
                logger.info("best_metirc: {}".format(best_metirc))
            reader_start = time.time()

        if rank == 0:
            # Save model checkpoint (latest weights, once per epoch)
            output_dir = os.path.join(args.output_dir, "latest_model")
            os.makedirs(output_dir, exist_ok=True)
            if distributed:
                model._layers.save_pretrained(output_dir)
            else:
                model.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)
            paddle.save(args, os.path.join(output_dir, "training_args.bin"))
            logger.info("Saving model checkpoint to {}".format(output_dir))
    logger.info("best_metirc: {}".format(best_metirc))
if __name__ == "__main__":
    args = parse_args()
    # make sure the output/log directory exists before training starts
    os.makedirs(args.output_dir, exist_ok=True)
    train(args)
ppstructure/vqa/train_ser.py
deleted
100644 → 0
View file @
1ded2ac4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
import
random
import
time
import
copy
import
logging
import
argparse
import
paddle
import
numpy
as
np
from
seqeval.metrics
import
classification_report
,
f1_score
,
precision_score
,
recall_score
from
paddlenlp.transformers
import
LayoutXLMModel
,
LayoutXLMTokenizer
,
LayoutXLMForTokenClassification
from
paddlenlp.transformers
import
LayoutLMModel
,
LayoutLMTokenizer
,
LayoutLMForTokenClassification
from
xfun
import
XFUNDataset
from
vqa_utils
import
parse_args
,
get_bio_label_maps
,
print_arguments
,
set_seed
from
eval_ser
import
evaluate
from
losses
import
SERLoss
from
ppocr.utils.logging
import
get_logger
# Map the --ser_model_type CLI value to the (tokenizer, backbone,
# token-classification head) classes used to build the SER model.
MODELS = {
    'LayoutXLM':
    (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
    'LayoutLM':
    (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
def train(args):
    """Train a LayoutLM/LayoutXLM SER (token classification) model on
    XFUN-format data.

    Saves the best model (by eval f1) to <output_dir>/best_model and the most
    recent weights to <output_dir>/latest_model.

    Returns:
        tuple: (global_step, average training loss per step).
    """
    os.makedirs(args.output_dir, exist_ok=True)
    rank = paddle.distributed.get_rank()
    distributed = paddle.distributed.get_world_size() > 1

    logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
    print_arguments(args, logger)

    label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
    loss_class = SERLoss(len(label2id_map))

    pad_token_label_id = loss_class.ignore_index

    # dist mode
    if distributed:
        paddle.distributed.init_parallel_env()

    tokenizer_class, base_model_class, model_class = MODELS[
        args.ser_model_type]
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    if not args.resume:
        # fresh run: wrap a pretrained backbone with an untrained SER head
        base_model = base_model_class.from_pretrained(args.model_name_or_path)
        model = model_class(
            base_model, num_classes=len(label2id_map), dropout=None)
        logger.info('train from scratch')
    else:
        logger.info('resume from {}'.format(args.model_name_or_path))
        model = model_class.from_pretrained(args.model_name_or_path)

    # dist mode
    if distributed:
        model = paddle.DataParallel(model)

    train_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.train_data_dir,
        label_path=args.train_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        pad_token_label_id=pad_token_label_id,
        contains_re=False,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    eval_dataset = XFUNDataset(
        tokenizer,
        data_dir=args.eval_data_dir,
        label_path=args.eval_label_path,
        label2id_map=label2id_map,
        img_size=(224, 224),
        pad_token_label_id=pad_token_label_id,
        contains_re=False,
        add_special_ids=False,
        return_attention_mask=True,
        load_mode='all')

    train_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)

    train_dataloader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=args.num_workers,
        use_shared_memory=True,
        collate_fn=None, )

    eval_dataloader = paddle.io.DataLoader(
        eval_dataset,
        batch_size=args.per_gpu_eval_batch_size,
        num_workers=args.num_workers,
        use_shared_memory=True,
        collate_fn=None, )

    t_total = len(train_dataloader) * args.num_train_epochs

    # build linear decay with warmup lr sch
    lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
        learning_rate=args.learning_rate,
        decay_steps=t_total,
        end_lr=0.0,
        power=1.0)
    if args.warmup_steps > 0:
        lr_scheduler = paddle.optimizer.lr.LinearWarmup(
            lr_scheduler,
            args.warmup_steps,
            start_lr=0,
            end_lr=args.learning_rate, )

    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        parameters=model.parameters(),
        epsilon=args.adam_epsilon,
        weight_decay=args.weight_decay)

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed) = %d",
        args.per_gpu_train_batch_size * paddle.distributed.get_world_size(), )
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss = 0.0
    set_seed(args.seed)
    best_metrics = None

    # rolling cost/throughput counters, reset after every log line
    train_reader_cost = 0.0
    train_run_cost = 0.0
    total_samples = 0
    reader_start = time.time()

    print_step = 1
    model.train()

    for epoch_id in range(args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            train_reader_cost += time.time() - reader_start

            if args.ser_model_type == 'LayoutLM':
                # LayoutLM takes no image feature input
                if 'image' in batch:
                    batch.pop('image')
            labels = batch.pop('labels')

            train_start = time.time()
            outputs = model(**batch)
            train_run_cost += time.time() - train_start
            if args.ser_model_type == 'LayoutXLM':
                # LayoutXLM returns a tuple; logits are the first element
                outputs = outputs[0]
            loss = loss_class(labels, outputs, batch['attention_mask'])

            # model outputs are always tuple in ppnlp (see doc)
            loss = loss.mean()
            loss.backward()
            tr_loss += loss.item()
            optimizer.step()
            lr_scheduler.step()  # Update learning rate schedule
            optimizer.clear_grad()
            global_step += 1
            total_samples += batch['input_ids'].shape[0]

            if rank == 0 and step % print_step == 0:
                logger.info(
                    "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
                    format(epoch_id, args.num_train_epochs, step,
                           len(train_dataloader), global_step,
                           loss.numpy()[0],
                           lr_scheduler.get_lr(), train_reader_cost /
                           print_step, (train_reader_cost + train_run_cost) /
                           print_step, total_samples / print_step,
                           total_samples / (train_reader_cost + train_run_cost
                                            )))

                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0

            if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training:
                # Log metrics
                # Only evaluate when single GPU otherwise metrics may not average well
                results, _ = evaluate(args, model, tokenizer, loss_class,
                                      eval_dataloader, label2id_map,
                                      id2label_map, pad_token_label_id,
                                      logger)

                # keep the checkpoint with the best eval f1 so far
                if best_metrics is None or results["f1"] >= best_metrics[
                        "f1"]:
                    best_metrics = copy.deepcopy(results)
                    output_dir = os.path.join(args.output_dir, "best_model")
                    os.makedirs(output_dir, exist_ok=True)
                    if distributed:
                        model._layers.save_pretrained(output_dir)
                    else:
                        model.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    paddle.save(args,
                                os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to {}".format(
                        output_dir))

                logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format(
                    epoch_id, args.num_train_epochs, step,
                    len(train_dataloader), results))
                if best_metrics is not None:
                    logger.info("best metrics: {}".format(best_metrics))
            reader_start = time.time()

        if rank == 0:
            # Save model checkpoint (latest weights, once per epoch)
            output_dir = os.path.join(args.output_dir, "latest_model")
            os.makedirs(output_dir, exist_ok=True)
            if distributed:
                model._layers.save_pretrained(output_dir)
            else:
                model.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)
            paddle.save(args, os.path.join(output_dir, "training_args.bin"))
            logger.info("Saving model checkpoint to {}".format(output_dir))
    return global_step, tr_loss / global_step
if __name__ == "__main__":
    # CLI entry point: parse command-line arguments and launch training.
    args = parse_args()
    train(args)
ppstructure/vqa/vqa_utils.py
deleted
100644 → 0
View file @
1ded2ac4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
argparse
import
cv2
import
random
import
numpy
as
np
import
imghdr
from
copy
import
deepcopy
import
paddle
from
PIL
import
Image
,
ImageDraw
,
ImageFont
def set_seed(seed):
    """Seed every RNG in play (python `random`, numpy, paddle) so runs are reproducible."""
    for seeder in (random.seed, np.random.seed, paddle.seed):
        seeder(seed)
def get_bio_label_maps(label_map_path):
    """Build BIO tag <-> id mappings from a label-list file.

    Every non-"O" line in the file yields a "B-" and an "I-" tag; "O" is
    prepended when absent so the outside tag always gets id 0.

    Args:
        label_map_path: text file with one raw label name per line.

    Returns:
        (label2id_map, id2label_map): forward and inverse dicts.
    """
    with open(label_map_path, "r", encoding="utf-8") as fin:
        raw_labels = [ln.strip() for ln in fin.readlines()]
    if "O" not in raw_labels:
        raw_labels.insert(0, "O")

    labels = []
    for name in raw_labels:
        # "O" stays as-is; every other label expands to its B-/I- pair
        labels.extend(["O"] if name == "O" else ["B-" + name, "I-" + name])

    label2id_map = {lab: i for i, lab in enumerate(labels)}
    id2label_map = {i: lab for i, lab in enumerate(labels)}
    return label2id_map, id2label_map
def get_image_file_list(img_file):
    """Return the sorted list of image paths found at *img_file*.

    *img_file* may be a single image file or a directory to scan (one level,
    no recursion). Files are filtered by sniffing content with ``imghdr``,
    not by extension.

    Raises:
        Exception: when *img_file* does not exist or yields no images.
    """
    img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}

    def _is_image(path):
        # content sniffing rather than extension matching
        return os.path.isfile(path) and imghdr.what(path) in img_end

    if img_file is None or not os.path.exists(img_file):
        raise Exception("not found any img file in {}".format(img_file))

    imgs_lists = []
    if _is_image(img_file):
        imgs_lists.append(img_file)
    elif os.path.isdir(img_file):
        imgs_lists = [
            os.path.join(img_file, name)
            for name in os.listdir(img_file)
            if _is_image(os.path.join(img_file, name))
        ]

    if not imgs_lists:
        raise Exception("not found any img file in {}".format(img_file))
    return sorted(imgs_lists)
def draw_ser_results(image,
                     ocr_results,
                     font_path="../../doc/fonts/simfang.ttf",
                     font_size=18):
    """Render SER predictions onto a copy of *image*.

    Each entry of *ocr_results* must carry "pred_id", "pred", "text" and
    "bbox"; boxes are filled with a per-class color and annotated with
    "<pred>: <text>", then alpha-blended (50%) over the original.
    Returns the result as a numpy array.
    """
    # fixed seed -> the per-class color assignment is stable across calls
    np.random.seed(2021)
    color = (np.random.permutation(range(255)),
             np.random.permutation(range(255)),
             np.random.permutation(range(255)))
    # class id -> RGB triple; id 0 is deliberately excluded (background/"O")
    color_map = {
        idx: (color[0][idx], color[1][idx], color[2][idx])
        for idx in range(1, 255)
    }
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    img_new = image.copy()
    draw = ImageDraw.Draw(img_new)

    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    for ocr_info in ocr_results:
        # skip classes with no assigned color (notably pred_id == 0)
        if ocr_info["pred_id"] not in color_map:
            continue
        color = color_map[ocr_info["pred_id"]]
        text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
        draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)

    # 50/50 blend keeps the underlying document visible under the overlay
    img_new = Image.blend(image, img_new, 0.5)
    return np.array(img_new)
def draw_box_txt(bbox, text, draw, font, font_size, color):
    """Fill one box with *color* and write *text* just above its top-left corner.

    *bbox* is a flat [x1, y1, x2, y2]; *draw* is a PIL ImageDraw bound to the
    target image. The text is drawn white on a blue background strip.
    """
    # draw ocr results outline
    bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
    draw.rectangle(bbox, fill=color)

    # draw ocr results: place the caption strip directly above the box,
    # clamped so it never leaves the top of the image
    start_y = max(0, bbox[0][1] - font_size)
    # NOTE(review): font.getsize was removed in Pillow >= 10 — if the pinned
    # Pillow is upgraded, this needs font.getlength/draw.textlength instead.
    tw = font.getsize(text)[0]
    draw.rectangle(
        [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)],
        fill=(0, 0, 255))
    draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
def draw_re_results(image,
                    result,
                    font_path="../../doc/fonts/simfang.ttf",
                    font_size=18):
    """Render RE (relation extraction) pairs onto a copy of *image*.

    *result* is a sequence of (head_info, tail_info) pairs, each info dict
    carrying "bbox" and "text". Heads are drawn blue, tails red, and a green
    line connects the two box centers. The overlay is alpha-blended (50%)
    over the original; returns a numpy array.
    """
    np.random.seed(0)
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    img_new = image.copy()
    draw = ImageDraw.Draw(img_new)

    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    color_head = (0, 0, 255)
    color_tail = (255, 0, 0)
    color_line = (0, 255, 0)

    for ocr_info_head, ocr_info_tail in result:
        draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
                     font_size, color_head)
        draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
                     font_size, color_tail)

        # connect the two boxes center-to-center
        center_head = (
            (ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
            (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
        center_tail = (
            (ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
            (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)

        draw.line([center_head, center_tail], fill=color_line, width=5)

    img_new = Image.blend(image, img_new, 0.5)
    return np.array(img_new)
# pad sentences
def
pad_sentences
(
tokenizer
,
encoded_inputs
,
max_seq_len
=
512
,
pad_to_max_seq_len
=
True
,
return_attention_mask
=
True
,
return_token_type_ids
=
True
,
return_overflowing_tokens
=
False
,
return_special_tokens_mask
=
False
):
# Padding with larger size, reshape is carried out
max_seq_len
=
(
len
(
encoded_inputs
[
"input_ids"
])
//
max_seq_len
+
1
)
*
max_seq_len
needs_to_be_padded
=
pad_to_max_seq_len
and
\
max_seq_len
and
len
(
encoded_inputs
[
"input_ids"
])
<
max_seq_len
if
needs_to_be_padded
:
difference
=
max_seq_len
-
len
(
encoded_inputs
[
"input_ids"
])
if
tokenizer
.
padding_side
==
'right'
:
if
return_attention_mask
:
encoded_inputs
[
"attention_mask"
]
=
[
1
]
*
len
(
encoded_inputs
[
"input_ids"
])
+
[
0
]
*
difference
if
return_token_type_ids
:
encoded_inputs
[
"token_type_ids"
]
=
(
encoded_inputs
[
"token_type_ids"
]
+
[
tokenizer
.
pad_token_type_id
]
*
difference
)
if
return_special_tokens_mask
:
encoded_inputs
[
"special_tokens_mask"
]
=
encoded_inputs
[
"special_tokens_mask"
]
+
[
1
]
*
difference
encoded_inputs
[
"input_ids"
]
=
encoded_inputs
[
"input_ids"
]
+
[
tokenizer
.
pad_token_id
]
*
difference
encoded_inputs
[
"bbox"
]
=
encoded_inputs
[
"bbox"
]
+
[[
0
,
0
,
0
,
0
]
]
*
difference
else
:
if
return_attention_mask
:
encoded_inputs
[
"attention_mask"
]
=
[
1
]
*
len
(
encoded_inputs
[
"input_ids"
])
return
encoded_inputs
def split_page(encoded_inputs, max_seq_len=512):
    """
    truncate is often used in training process

    Reshape every padded field into [-1, max_seq_len] pages (the sequence
    length must already be a multiple of max_seq_len — see pad_sentences),
    converting each to a paddle tensor. 'entities' is not page-shaped and is
    just wrapped in a one-element list. Mutates and returns *encoded_inputs*.
    """
    for key in encoded_inputs:
        if key == 'entities':
            encoded_inputs[key] = [encoded_inputs[key]]
            continue
        encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
        if encoded_inputs[key].ndim <= 1:
            # for input_ids, att_mask and so on: flat token-level fields
            encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
        else:
            # for bbox: one 4-tuple per token
            encoded_inputs[key] = encoded_inputs[key].reshape(
                [-1, max_seq_len, 4])
    return encoded_inputs
def preprocess(tokenizer,
               ori_img,
               ocr_info,
               img_size=(224, 224),
               pad_token_label_id=-100,
               max_seq_len=512,
               add_special_ids=False,
               return_attention_mask=True, ):
    """Turn one image plus its OCR results into model-ready batched inputs.

    Tokenizes each OCR segment, normalizes its bbox to the 0..1000 LayoutXLM
    coordinate space, pads the flat sequence and splits it into max_seq_len
    pages; the resized image is replicated once per page.

    Args:
        tokenizer: LayoutXLM-style tokenizer (dict-returning ``encode``).
        ori_img: HWC image array (numpy).
        ocr_info: list of dicts with "text" and "bbox" [x1, y1, x2, y2];
            deep-copied, so the caller's list is not mutated.

    Returns:
        dict of paddle tensors plus "entities" (RE placeholders, label "O")
        and "segment_offset_id" (cumulative token end offset per segment).
    """
    ocr_info = deepcopy(ocr_info)
    height = ori_img.shape[0]
    width = ori_img.shape[1]

    # CHW float32 image at the model's input resolution
    img = cv2.resize(ori_img,
                     img_size).transpose([2, 0, 1]).astype(np.float32)

    segment_offset_id = []
    words_list = []
    bbox_list = []
    input_ids_list = []
    token_type_ids_list = []
    entities = []

    for info in ocr_info:
        # x1, y1, x2, y2 — scaled in place into the 0..1000 layout grid
        bbox = info["bbox"]
        bbox[0] = int(bbox[0] * 1000.0 / width)
        bbox[2] = int(bbox[2] * 1000.0 / width)
        bbox[1] = int(bbox[1] * 1000.0 / height)
        bbox[3] = int(bbox[3] * 1000.0 / height)

        text = info["text"]
        encode_res = tokenizer.encode(
            text, pad_to_max_seq_len=False, return_attention_mask=True)

        if not add_special_ids:
            # TODO: use tok.all_special_ids to remove
            # strip [CLS]/[SEP] so segments concatenate cleanly
            encode_res["input_ids"] = encode_res["input_ids"][1:-1]
            encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
            encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]

        # for re: every segment becomes a candidate entity with unknown label
        entities.append({
            "start": len(input_ids_list),
            "end": len(input_ids_list) + len(encode_res["input_ids"]),
            "label": "O",
        })

        input_ids_list.extend(encode_res["input_ids"])
        token_type_ids_list.extend(encode_res["token_type_ids"])
        # the whole segment shares one (already scaled) bbox
        bbox_list.extend([bbox] * len(encode_res["input_ids"]))
        words_list.append(text)
        segment_offset_id.append(len(input_ids_list))

    encoded_inputs = {
        "input_ids": input_ids_list,
        "token_type_ids": token_type_ids_list,
        "bbox": bbox_list,
        "attention_mask": [1] * len(input_ids_list),
        "entities": entities
    }

    encoded_inputs = pad_sentences(
        tokenizer,
        encoded_inputs,
        max_seq_len=max_seq_len,
        return_attention_mask=return_attention_mask)

    encoded_inputs = split_page(encoded_inputs)

    # one copy of the image per page so pages can be fed as a batch
    fake_bs = encoded_inputs["input_ids"].shape[0]
    encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
        [fake_bs] + list(img.shape))

    encoded_inputs["segment_offset_id"] = segment_offset_id
    return encoded_inputs
def postprocess(attention_mask, preds, id2label_map):
    """Map per-token logits to label strings, dropping padded positions.

    Args:
        attention_mask: [batch, seq] 0/1 mask; only positions == 1 are kept.
        preds: [batch, seq, num_labels] logits (paddle tensor or numpy array).
        id2label_map: label id -> label string.

    Returns:
        list (one entry per batch row) of label-string lists.
    """
    if isinstance(preds, paddle.Tensor):
        preds = preds.numpy()
    pred_ids = np.argmax(preds, axis=2)

    # keep batch info: build one label list per batch row
    preds_list = []
    for i in range(pred_ids.shape[0]):
        row = [
            id2label_map[pred_ids[i][j]] for j in range(pred_ids.shape[1])
            if attention_mask[i][j] == 1
        ]
        preds_list.append(row)
    return preds_list
def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list,
                                   label2id_map_for_draw):
    """Attach a majority-vote label to each OCR segment.

    # must ensure the preds_list is generated from the same image

    Args:
        ocr_info: per-segment dicts, mutated in place with "pred_id"/"pred".
        segment_offset_id: cumulative token end offset of each segment.
        preds_list: per-page lists of predicted BIO label strings.
        label2id_map_for_draw: BIO label string -> drawing class id.

    Returns:
        the mutated *ocr_info*.
    """
    # flatten page-wise predictions back into one token stream
    flat_preds = [tag for page in preds_list for tag in page]

    # invert the drawing map; strip the BIO prefix so B-/I- collapse to one name
    id2label_map = {}
    for key, val in label2id_map_for_draw.items():
        id2label_map[val] = key[2:] if key.startswith(("B-", "I-")) else key

    boundaries = [0] + list(segment_offset_id)
    for idx, (start_id, end_id) in enumerate(zip(boundaries, boundaries[1:])):
        curr_ids = [
            label2id_map_for_draw[tag] for tag in flat_preds[start_id:end_id]
        ]
        # empty segments default to class 0 ("O")
        pred_id = int(np.argmax(np.bincount(curr_ids))) if curr_ids else 0
        ocr_info[idx]["pred_id"] = pred_id
        ocr_info[idx]["pred"] = id2label_map[pred_id]
    return ocr_info
def
print_arguments
(
args
,
logger
=
None
):
print_func
=
logger
.
info
if
logger
is
not
None
else
print
"""print arguments"""
print_func
(
'----------- Configuration Arguments -----------'
)
for
arg
,
value
in
sorted
(
vars
(
args
).
items
()):
print_func
(
'%s: %s'
%
(
arg
,
value
))
print_func
(
'------------------------------------------------'
)
def parse_args():
    """Build and parse the command-line arguments for VQA SER/RE scripts.

    Covers model selection, train/eval data locations, optimizer
    hyper-parameters and inference inputs. Returns an argparse.Namespace.
    """
    parser = argparse.ArgumentParser()
    # Required parameters
    # yapf: disable
    # --- model / data locations ---
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,)
    parser.add_argument("--ser_model_type", default='LayoutXLM', type=str)
    parser.add_argument("--re_model_name_or_path", default=None, type=str, required=False,)
    parser.add_argument("--train_data_dir", default=None, type=str, required=False,)
    parser.add_argument("--train_label_path", default=None, type=str, required=False,)
    parser.add_argument("--eval_data_dir", default=None, type=str, required=False,)
    parser.add_argument("--eval_label_path", default=None, type=str, required=False,)
    parser.add_argument("--output_dir", default=None, type=str, required=True,)
    parser.add_argument("--max_seq_length", default=512, type=int,)
    parser.add_argument("--evaluate_during_training", action="store_true",)
    parser.add_argument("--num_workers", default=8, type=int,)
    # --- optimization hyper-parameters ---
    parser.add_argument("--per_gpu_train_batch_size", default=8,
                        type=int, help="Batch size per GPU/CPU for training.",)
    parser.add_argument("--per_gpu_eval_batch_size", default=8,
                        type=int, help="Batch size per GPU/CPU for eval.",)
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.",)
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay if we apply some.",)
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.",)
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.",)
    parser.add_argument("--num_train_epochs", default=3, type=int,
                        help="Total number of training epochs to perform.",)
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.",)
    parser.add_argument("--eval_steps", type=int, default=10,
                        help="eval every X updates steps.",)
    parser.add_argument("--seed", type=int, default=2048,
                        help="random seed for initialization",)
    # --- inference-time inputs ---
    parser.add_argument("--rec_model_dir", default=None, type=str, )
    parser.add_argument("--det_model_dir", default=None, type=str, )
    parser.add_argument(
        "--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
    parser.add_argument("--infer_imgs", default=None, type=str, required=False)
    parser.add_argument("--resume", action='store_true')
    parser.add_argument("--ocr_json_path", default=None,
                        type=str, required=False, help="ocr prediction results")
    # yapf: enable
    args = parser.parse_args()
    return args
ppstructure/vqa/xfun.py
deleted
100644 → 0
View file @
1ded2ac4
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
os
import
cv2
import
numpy
as
np
import
paddle
import
copy
from
paddle.io
import
Dataset
__all__
=
[
"XFUNDataset"
]
class XFUNDataset(Dataset):
    """Dataset for the XFUN visually-rich document benchmark (SER / RE).

    Each line of ``label_path`` holds an image file name and a JSON
    annotation blob separated by a tab. Samples are tokenized per OCR
    segment, BIO-labelled, padded and chunked into fixed 512-token pages;
    when ``contains_re`` is True, entity/relation structures for relation
    extraction are built as well.

    Example:
        print("=====begin to build dataset=====")
        from paddlenlp.transformers import LayoutXLMTokenizer
        tokenizer = LayoutXLMTokenizer.from_pretrained("/paddle/models/transformers/layoutxlm-base-paddle/")
        tok_res = tokenizer.tokenize("Maribyrnong")
        # res = tokenizer.convert_ids_to_tokens(val_data["input_ids"][0])
        dataset = XfunDatasetForSer(
            tokenizer,
            data_dir="./zh.val/",
            label_path="zh.val/xfun_normalize_val.json",
            img_size=(224,224))
        print(len(dataset))
        data = dataset[0]
        print(data.keys())
        print("input_ids: ", data["input_ids"])
        print("labels: ", data["labels"])
        print("token_type_ids: ", data["token_type_ids"])
        print("words_list: ", data["words_list"])
        print("image shape: ", data["image"].shape)
    """

    def __init__(self,
                 tokenizer,
                 data_dir,
                 label_path,
                 contains_re=False,
                 label2id_map=None,
                 img_size=(224, 224),
                 pad_token_label_id=None,
                 add_special_ids=False,
                 return_attention_mask=True,
                 load_mode='all',
                 max_seq_len=512):
        super().__init__()
        self.tokenizer = tokenizer
        self.data_dir = data_dir
        self.label_path = label_path
        # True -> also build entities/relations (RE task); False -> SER only
        self.contains_re = contains_re
        self.label2id_map = label2id_map
        self.img_size = img_size
        self.pad_token_label_id = pad_token_label_id
        self.add_special_ids = add_special_ids
        self.return_attention_mask = return_attention_mask
        # 'all': pre-parse every line up front; anything else parses lazily
        self.load_mode = load_mode
        self.max_seq_len = max_seq_len
        if self.pad_token_label_id is None:
            # default: the loss ignore_index, so padded tokens never
            # contribute to the loss
            self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
        self.all_lines = self.read_all_lines()
        self.entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
        # keys copied into the returned sample, with the container/dtype each
        # one is converted to in __getitem__
        self.return_keys = {
            'bbox': {
                'type': 'np',
                'dtype': 'int64'
            },
            'input_ids': {
                'type': 'np',
                'dtype': 'int64'
            },
            'labels': {
                'type': 'np',
                'dtype': 'int64'
            },
            'attention_mask': {
                'type': 'np',
                'dtype': 'int64'
            },
            'image': {
                'type': 'np',
                'dtype': 'float32'
            },
            'token_type_ids': {
                'type': 'np',
                'dtype': 'int64'
            },
            'entities': {
                'type': 'dict'
            },
            'relations': {
                'type': 'dict'
            }
        }

        if load_mode == "all":
            self.encoded_inputs_all = self._parse_label_file_all()

    def pad_sentences(self,
                      encoded_inputs,
                      max_seq_len=512,
                      pad_to_max_seq_len=True,
                      return_attention_mask=True,
                      return_token_type_ids=True,
                      truncation_strategy="longest_first",
                      return_overflowing_tokens=False,
                      return_special_tokens_mask=False):
        """Pad token-level fields (ids, labels, bbox, masks) to max_seq_len,
        honoring the tokenizer's padding side. Mutates and returns the dict."""
        # Padding
        needs_to_be_padded = pad_to_max_seq_len and \
            max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len

        if needs_to_be_padded:
            difference = max_seq_len - len(encoded_inputs["input_ids"])
            if self.tokenizer.padding_side == 'right':
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = [1] * len(
                        encoded_inputs["input_ids"]) + [0] * difference
                if return_token_type_ids:
                    encoded_inputs["token_type_ids"] = (
                        encoded_inputs["token_type_ids"] +
                        [self.tokenizer.pad_token_type_id] * difference)
                if return_special_tokens_mask:
                    encoded_inputs["special_tokens_mask"] = encoded_inputs[
                        "special_tokens_mask"] + [1] * difference
                encoded_inputs["input_ids"] = encoded_inputs[
                    "input_ids"] + [self.tokenizer.pad_token_id] * difference
                # padded label positions get the ignore_index so the loss
                # skips them
                encoded_inputs["labels"] = encoded_inputs[
                    "labels"] + [self.pad_token_label_id] * difference
                encoded_inputs["bbox"] = encoded_inputs[
                    "bbox"] + [[0, 0, 0, 0]] * difference
            elif self.tokenizer.padding_side == 'left':
                if return_attention_mask:
                    encoded_inputs["attention_mask"] = [0] * difference + [
                        1
                    ] * len(encoded_inputs["input_ids"])
                if return_token_type_ids:
                    encoded_inputs["token_type_ids"] = (
                        [self.tokenizer.pad_token_type_id] * difference +
                        encoded_inputs["token_type_ids"])
                if return_special_tokens_mask:
                    encoded_inputs["special_tokens_mask"] = [
                        1
                    ] * difference + encoded_inputs["special_tokens_mask"]
                encoded_inputs["input_ids"] = [
                    self.tokenizer.pad_token_id
                ] * difference + encoded_inputs["input_ids"]
                encoded_inputs["labels"] = [
                    self.pad_token_label_id
                ] * difference + encoded_inputs["labels"]
                encoded_inputs["bbox"] = [
                    [0, 0, 0, 0]
                ] * difference + encoded_inputs["bbox"]
        else:
            if return_attention_mask:
                encoded_inputs["attention_mask"] = [1] * len(
                    encoded_inputs["input_ids"])

        return encoded_inputs

    def truncate_inputs(self, encoded_inputs, max_seq_len=512):
        """Clip every per-token field to max_seq_len; 'sample_id' is scalar
        metadata and is left untouched."""
        for key in encoded_inputs:
            if key == "sample_id":
                continue
            length = min(len(encoded_inputs[key]), max_seq_len)
            encoded_inputs[key] = encoded_inputs[key][:length]
        return encoded_inputs

    def read_all_lines(self, ):
        """Read every annotation line from label_path."""
        with open(self.label_path, "r", encoding='utf-8') as fin:
            lines = fin.readlines()
        return lines

    def _parse_label_file_all(self):
        """
        parse all samples

        Eagerly parses every annotation line into its list of chunked,
        model-ready examples (used when load_mode == 'all').
        """
        encoded_inputs_all = []
        for line in self.all_lines:
            encoded_inputs_all.extend(self._parse_label_file(line))
        return encoded_inputs_all

    def _parse_label_file(self, line):
        """
        parse single sample

        Splits an annotation line into image name + JSON info, encodes the
        sample and chunks it into 512-token pages, tagging each page with the
        image path so __getitem__ can load the image lazily.
        """
        image_name, info_str = line.split("\t")
        image_path = os.path.join(self.data_dir, image_name)

        def add_imgge_path(x):
            # attach the source image path to each chunked example
            x['image_path'] = image_path
            return x

        encoded_inputs = self._read_encoded_inputs_sample(info_str)
        if self.contains_re:
            encoded_inputs = self._chunk_re(encoded_inputs)
        else:
            encoded_inputs = self._chunk_ser(encoded_inputs)
        encoded_inputs = list(map(add_imgge_path, encoded_inputs))
        return encoded_inputs

    def _read_encoded_inputs_sample(self, info_str):
        """
        parse label info

        Tokenizes each OCR segment, scales its bbox into the 0..1000 layout
        grid, assigns BIO labels per token, then pads and truncates the
        concatenated sequence. When contains_re is set, also collects
        entities and question->answer relations.
        """
        # read text info
        info_dict = json.loads(info_str)
        height = info_dict["height"]
        width = info_dict["width"]

        words_list = []
        bbox_list = []
        input_ids_list = []
        token_type_ids_list = []
        gt_label_list = []

        if self.contains_re:
            # for re
            entities = []
            relations = []
            id2label = {}
            entity_id_to_index_map = {}
            empty_entity = set()
        for info in info_dict["ocr_info"]:
            if self.contains_re:
                # for re: drop empty-text entities but remember their ids so
                # relations touching them can be filtered later
                if len(info["text"]) == 0:
                    empty_entity.add(info["id"])
                    continue
                id2label[info["id"]] = info["label"]
                relations.extend([tuple(sorted(l)) for l in info["linking"]])

            # x1, y1, x2, y2 — scaled in place into the 0..1000 layout grid
            bbox = info["bbox"]
            label = info["label"]
            bbox[0] = int(bbox[0] * 1000.0 / width)
            bbox[2] = int(bbox[2] * 1000.0 / width)
            bbox[1] = int(bbox[1] * 1000.0 / height)
            bbox[3] = int(bbox[3] * 1000.0 / height)

            text = info["text"]
            encode_res = self.tokenizer.encode(
                text, pad_to_max_seq_len=False, return_attention_mask=True)

            gt_label = []
            if not self.add_special_ids:
                # TODO: use tok.all_special_ids to remove
                # strip [CLS]/[SEP] so segments concatenate cleanly
                encode_res["input_ids"] = encode_res["input_ids"][1:-1]
                encode_res["token_type_ids"] = encode_res["token_type_ids"][
                    1:-1]
                encode_res["attention_mask"] = encode_res["attention_mask"][
                    1:-1]
            if label.lower() == "other":
                # "other" maps to the outside tag (id 0) for every token
                gt_label.extend([0] * len(encode_res["input_ids"]))
            else:
                # BIO: first token gets B-, remaining tokens get I-
                gt_label.append(self.label2id_map[("b-" + label).upper()])
                gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
                                (len(encode_res["input_ids"]) - 1))
            if self.contains_re:
                # only non-"O" segments become RE entities
                if gt_label[0] != self.label2id_map["O"]:
                    entity_id_to_index_map[info["id"]] = len(entities)
                    entities.append({
                        "start": len(input_ids_list),
                        "end":
                        len(input_ids_list) + len(encode_res["input_ids"]),
                        "label": label.upper(),
                    })
            input_ids_list.extend(encode_res["input_ids"])
            token_type_ids_list.extend(encode_res["token_type_ids"])
            # the whole segment shares one (already scaled) bbox
            bbox_list.extend([bbox] * len(encode_res["input_ids"]))
            gt_label_list.extend(gt_label)
            words_list.append(text)

        encoded_inputs = {
            "input_ids": input_ids_list,
            "labels": gt_label_list,
            "token_type_ids": token_type_ids_list,
            "bbox": bbox_list,
            "attention_mask": [1] * len(input_ids_list),
            # "words_list": words_list,
        }
        encoded_inputs = self.pad_sentences(
            encoded_inputs,
            max_seq_len=self.max_seq_len,
            return_attention_mask=self.return_attention_mask)
        encoded_inputs = self.truncate_inputs(encoded_inputs)

        if self.contains_re:
            relations = self._relations(entities, relations, id2label,
                                        empty_entity, entity_id_to_index_map)
            encoded_inputs['relations'] = relations
            encoded_inputs['entities'] = entities
        return encoded_inputs

    def _chunk_ser(self, encoded_inputs):
        """Split the flat SER sample into consecutive 512-token pages."""
        encoded_inputs_all = []
        seq_len = len(encoded_inputs['input_ids'])
        chunk_size = 512
        for chunk_id, index in enumerate(range(0, seq_len, chunk_size)):
            chunk_beg = index
            chunk_end = min(index + chunk_size, seq_len)
            encoded_inputs_example = {}
            for key in encoded_inputs:
                encoded_inputs_example[key] = encoded_inputs[key][chunk_beg:
                                                                  chunk_end]
            encoded_inputs_all.append(encoded_inputs_example)
        return encoded_inputs_all

    def _chunk_re(self, encoded_inputs):
        """Split an RE sample into 512-token pages, keeping only the entities
        and relations that fall entirely inside each page (indices re-based
        to the page start)."""
        # prepare data
        entities = encoded_inputs.pop('entities')
        relations = encoded_inputs.pop('relations')
        encoded_inputs_all = []
        chunk_size = 512
        for chunk_id, index in enumerate(
                range(0, len(encoded_inputs["input_ids"]), chunk_size)):
            item = {}
            for k in encoded_inputs:
                item[k] = encoded_inputs[k][index:index + chunk_size]

            # select entity in current chunk
            entities_in_this_span = []
            global_to_local_map = {}  #
            for entity_id, entity in enumerate(entities):
                if (index <= entity["start"] < index + chunk_size and
                        index <= entity["end"] < index + chunk_size):
                    # shift token offsets to be page-relative
                    entity["start"] = entity["start"] - index
                    entity["end"] = entity["end"] - index
                    global_to_local_map[entity_id] = len(
                        entities_in_this_span)
                    entities_in_this_span.append(entity)

            # select relations in current chunk
            relations_in_this_span = []
            for relation in relations:
                if (index <= relation["start_index"] < index + chunk_size and
                        index <= relation["end_index"] < index + chunk_size):
                    relations_in_this_span.append({
                        "head": global_to_local_map[relation["head"]],
                        "tail": global_to_local_map[relation["tail"]],
                        "start_index": relation["start_index"] - index,
                        "end_index": relation["end_index"] - index,
                    })
            # columnar layout: dict-of-lists instead of list-of-dicts
            item.update({
                "entities": reformat(entities_in_this_span),
                "relations": reformat(relations_in_this_span),
            })
            item['entities']['label'] = [
                self.entities_labels[x] for x in item['entities']['label']
            ]
            encoded_inputs_all.append(item)
        return encoded_inputs_all

    def _relations(self, entities, relations, id2label, empty_entity,
                   entity_id_to_index_map):
        """
        build relations

        Deduplicates links, drops those touching empty entities, orients every
        pair as question->answer, and annotates each with the token span
        covering both ends. Result is sorted by head entity index.
        """
        relations = list(set(relations))
        relations = [
            rel for rel in relations
            if rel[0] not in empty_entity and rel[1] not in empty_entity
        ]
        kv_relations = []
        for rel in relations:
            pair = [id2label[rel[0]], id2label[rel[1]]]
            if pair == ["question", "answer"]:
                kv_relations.append({
                    "head": entity_id_to_index_map[rel[0]],
                    "tail": entity_id_to_index_map[rel[1]]
                })
            elif pair == ["answer", "question"]:
                kv_relations.append({
                    "head": entity_id_to_index_map[rel[1]],
                    "tail": entity_id_to_index_map[rel[0]]
                })
            else:
                # any other label pairing is not a key/value relation
                continue
        relations = sorted(
            [{
                "head": rel["head"],
                "tail": rel["tail"],
                "start_index": get_relation_span(rel, entities)[0],
                "end_index": get_relation_span(rel, entities)[1],
            } for rel in kv_relations],
            key=lambda x: x["head"], )
        return relations

    def load_img(self, image_path):
        """Load an image as a normalized CHW float array at self.img_size."""
        # read img
        img = cv2.imread(image_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        resize_h, resize_w = self.img_size
        im_shape = img.shape[0:2]
        im_scale_y = resize_h / im_shape[0]
        im_scale_x = resize_w / im_shape[1]
        img_new = cv2.resize(
            img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=2)
        # ImageNet mean/std normalization
        mean = np.array([0.485, 0.456, 0.406])[np.newaxis, np.newaxis, :]
        std = np.array([0.229, 0.224, 0.225])[np.newaxis, np.newaxis, :]
        img_new = img_new / 255.0
        img_new -= mean
        img_new /= std
        img = img_new.transpose((2, 0, 1))
        return img

    def __getitem__(self, idx):
        if self.load_mode == "all":
            # deepcopy: callers may mutate the sample (e.g. pop image_path)
            data = copy.deepcopy(self.encoded_inputs_all[idx])
        else:
            # lazy mode: parse on demand; only the first page of the line
            # is served
            data = self._parse_label_file(self.all_lines[idx])[0]

        image_path = data.pop('image_path')
        data["image"] = self.load_img(image_path)

        # keep only the declared keys, converting array-typed ones to numpy
        return_data = {}
        for k, v in data.items():
            if k in self.return_keys:
                if self.return_keys[k]['type'] == 'np':
                    v = np.array(v, dtype=self.return_keys[k]['dtype'])
                return_data[k] = v
        return return_data

    def __len__(self, ):
        if self.load_mode == "all":
            return len(self.encoded_inputs_all)
        else:
            return len(self.all_lines)
def get_relation_span(rel, entities):
    """Return the (min start, max end) token span covering a relation's
    head and tail entities."""
    head = entities[rel["head"]]
    tail = entities[rel["tail"]]
    bounds = (head["start"], head["end"], tail["start"], tail["end"])
    return min(bounds), max(bounds)
def reformat(data):
    """Convert a list of dicts into a dict of lists (list-of-records ->
    columnar), unioning keys and preserving encounter order."""
    merged = {}
    for record in data:
        for key, value in record.items():
            merged.setdefault(key, []).append(value)
    return merged
requirements.txt
View file @
a323fce6
...
...
@@ -13,4 +13,3 @@ lxml
premailer
openpyxl
fasttext
==0.9.1
paddlenlp
>=2.2.1
tools/eval.py
View file @
a323fce6
...
...
@@ -61,7 +61,8 @@ def main():
else
:
model_type
=
None
best_model_dict
=
load_model
(
config
,
model
)
best_model_dict
=
load_model
(
config
,
model
,
model_type
=
config
[
'Architecture'
][
"model_type"
])
if
len
(
best_model_dict
):
logger
.
info
(
'metric in ckpt ***************'
)
for
k
,
v
in
best_model_dict
.
items
():
...
...
tools/infer_vqa_token_ser.py
0 → 100755
View file @
a323fce6
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
import
cv2
import
json
import
paddle
from
ppocr.data
import
create_operators
,
transform
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
load_model
from
ppocr.utils.visual
import
draw_ser_results
from
ppocr.utils.utility
import
get_image_file_list
,
load_vqa_bio_label_maps
import
tools.program
as
program
def to_tensor(data):
    """Batch a list of sample fields, converting numeric entries to tensors.

    Fields that are ndarrays, paddle Tensors or plain numbers are wrapped in
    a one-element list and converted with ``paddle.to_tensor``; every other
    field is passed through as a one-element list unchanged.
    """
    import numbers
    from collections import defaultdict

    grouped = defaultdict(list)
    tensor_positions = []
    for position, value in enumerate(data):
        if isinstance(value, (np.ndarray, paddle.Tensor, numbers.Number)):
            if position not in tensor_positions:
                tensor_positions.append(position)
        grouped[position].append(value)

    for position in tensor_positions:
        grouped[position] = paddle.to_tensor(grouped[position])
    # defaultdict preserves insertion order, i.e. the original field order.
    return list(grouped.values())
class SerPredictor(object):
    """End-to-end semantic entity recognition (SER) predictor.

    Bundles an OCR engine, the Eval transform pipeline (switched to
    inference mode) and the SER model behind ``__call__(img_path)``.
    """

    def __init__(self, config):
        global_config = config['Global']

        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)

        # build model
        self.model = build_model(config['Architecture'])

        load_model(
            config, self.model, model_type=config['Architecture']["model_type"])

        from paddleocr import PaddleOCR

        # OCR engine produces the boxes/texts consumed by the label
        # transform below (no ground-truth labels at inference time).
        self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)

        # create data ops
        # NOTE: the op dicts are mutated in place, so the config list passed
        # to create_operators below already carries these changes; the local
        # `transforms` list mirrors it but is otherwise unused.
        transforms = []
        for op in config['Eval']['dataset']['transforms']:
            op_name = list(op)[0]
            if 'Label' in op_name:
                # Inject the OCR engine so the transform runs OCR instead of
                # reading annotations.
                op[op_name]['ocr_engine'] = self.ocr_engine
            elif op_name == 'KeepKeys':
                # This order fixes the positional indices used in __call__
                # (attention_mask=4, segment_offset_id=6, ocr_info=7).
                op[op_name]['keep_keys'] = [
                    'input_ids', 'labels', 'bbox', 'image', 'attention_mask',
                    'token_type_ids', 'segment_offset_id', 'ocr_info',
                    'entities'
                ]
            transforms.append(op)
        # infer_mode is set before create_operators — presumably the
        # operators read it at build time; TODO confirm against ppocr.data.
        global_config['infer_mode'] = True
        self.ops = create_operators(config['Eval']['dataset']['transforms'],
                                    global_config)
        self.model.eval()

    def __call__(self, img_path):
        # Raw bytes in; the transform pipeline handles decoding.
        with open(img_path, 'rb') as f:
            img = f.read()
        data = {'image': img}
        batch = transform(data, self.ops)
        batch = to_tensor(batch)
        preds = self.model(batch)
        # Positional indices follow the KeepKeys order set in __init__.
        post_result = self.post_process_class(
            preds,
            attention_masks=batch[4],
            segment_offset_ids=batch[6],
            ocr_infos=batch[7])
        return post_result, batch
if __name__ == '__main__':
    # SER inference entry point: run the SER engine over every input image,
    # dump JSON results to infer_results.txt and save visualizations.
    config, device, logger, vdl_writer = program.preprocess()
    os.makedirs(config['Global']['save_res_path'], exist_ok=True)

    ser_engine = SerPredictor(config)

    infer_imgs = get_image_file_list(config['Global']['infer_img'])
    with open(
            os.path.join(config['Global']['save_res_path'],
                         "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, img_path in enumerate(infer_imgs):
            # One visualization image per input, "<stem>_ser.jpg".
            save_img_path = os.path.join(
                config['Global']['save_res_path'],
                os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
            logger.info("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))

            # Only the post-processed result is needed here; the raw batch
            # returned by SerPredictor is discarded.
            result, _ = ser_engine(img_path)
            result = result[0]
            # One tab-separated line per image: path + JSON payload.
            fout.write(img_path + "\t" + json.dumps(
                {
                    "ocr_info": result,
                }, ensure_ascii=False) + "\n")
            img_res = draw_ser_results(img_path, result)
            cv2.imwrite(save_img_path, img_res)
tools/infer_vqa_token_ser_re.py
0 → 100755
View file @
a323fce6
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
import
cv2
import
json
import
paddle
import
paddle.distributed
as
dist
from
ppocr.data
import
create_operators
,
transform
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
load_model
from
ppocr.utils.visual
import
draw_re_results
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.utility
import
get_image_file_list
,
load_vqa_bio_label_maps
,
print_dict
from
tools.program
import
ArgsParser
,
load_config
,
merge_config
,
check_gpu
from
tools.infer_vqa_token_ser
import
SerPredictor
class ReArgsParser(ArgsParser):
    """Argument parser for RE inference that additionally requires a SER
    config file (``-c_ser``) and optional SER config overrides (``-o_ser``).
    """

    def __init__(self):
        super(ReArgsParser, self).__init__()
        self.add_argument(
            "-c_ser", "--config_ser", help="ser configuration file to use")
        self.add_argument(
            "-o_ser",
            "--opt_ser",
            nargs='+',
            help="set ser configuration options ")

    def parse_args(self, argv=None):
        parsed = super(ReArgsParser, self).parse_args(argv)
        # The SER config is mandatory for the two-stage pipeline.
        assert parsed.config_ser is not None, \
            "Please specify --config_ser=ser_configure_file_path."
        # Reuse the base class' key=value option parsing for SER overrides.
        parsed.opt_ser = self._parse_opt(parsed.opt_ser)
        return parsed
def make_input(ser_inputs, ser_results):
    """Convert SER outputs into the input layout expected by the RE model.

    Builds candidate (QUESTION, ANSWER) relation pairs from the predicted
    entities, replicates them across the batch, and prunes the SER-only
    fields from ``ser_inputs`` (which is modified in place and returned).
    """
    entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}

    entities = ser_inputs[8][0]
    ser_results = ser_results[0]
    assert len(entities) == len(ser_results)

    # Collect non-'O' entities; map each kept slot back to its original
    # entity index so predictions can be traced later.
    starts, ends, labels = [], [], []
    entity_idx_dict = {}
    for orig_idx, (res, entity) in enumerate(zip(ser_results, entities)):
        if res['pred'] == 'O':
            continue
        entity_idx_dict[len(starts)] = orig_idx
        starts.append(entity['start'])
        ends.append(entity['end'])
        labels.append(entities_labels[res['pred']])
    entities = dict(start=starts, end=ends, label=labels)

    # Candidate relations: every QUESTION (1) paired with every ANSWER (2).
    heads, tails = [], []
    for i, label_i in enumerate(labels):
        for j, label_j in enumerate(labels):
            if label_i == 1 and label_j == 2:
                heads.append(i)
                tails.append(j)
    relations = dict(head=heads, tail=tails)

    # Replicate across the batch; every batch item shares the same objects.
    batch_size = ser_inputs[0].shape[0]
    entities_batch = [entities] * batch_size
    relations_batch = [relations] * batch_size
    entity_idx_dict_batch = [entity_idx_dict] * batch_size

    ser_inputs[8] = entities_batch
    ser_inputs.append(relations_batch)
    # Drop ocr_info (7), segment_offset_id (6) and labels (1) — SER-only.
    for drop_idx in (7, 6, 1):
        ser_inputs.pop(drop_idx)
    return ser_inputs, entity_idx_dict_batch
class SerRePredictor(object):
    """Two-stage predictor: SER first, then relation extraction (RE).

    ``__call__`` runs the SER engine on an image, converts its output into
    RE model inputs via ``make_input`` and returns the post-processed
    relations.
    """

    def __init__(self, config, ser_config):
        # Stage 1: semantic entity recognition.
        self.ser_engine = SerPredictor(ser_config)

        # Stage 2: init re model
        global_config = config['Global']

        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)

        # build model
        self.model = build_model(config['Architecture'])

        load_model(
            config, self.model, model_type=config['Architecture']["model_type"])

        self.model.eval()

    def __call__(self, img_path):
        ser_results, ser_inputs = self.ser_engine(img_path)
        # Removed leftover debug dumps: the original called
        # paddle.save(ser_inputs, 'ser_inputs.npy') / paddle.save(
        # ser_results, 'ser_results.npy') here, silently writing two files
        # into the working directory on every inference call.
        re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
        preds = self.model(re_input)
        post_result = self.post_process_class(
            preds,
            ser_results=ser_results,
            entity_idx_dict_batch=entity_idx_dict_batch)
        return post_result
def preprocess():
    """Parse CLI args, load the RE and SER configs and select the device.

    Returns:
        tuple: ``(config, ser_config, device, logger)`` — ``config`` is the
        RE config, ``ser_config`` the SER config for the first stage.
    """
    FLAGS = ReArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)

    # The SER stage has its own config file and its own overrides.
    ser_config = load_config(FLAGS.config_ser)
    ser_config = merge_config(ser_config, FLAGS.opt_ser)

    logger = get_logger(name='root')

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    check_gpu(use_gpu)

    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)

    # Log both configs for reproducibility.
    logger.info('{} re config {}'.format('*' * 10, '*' * 10))
    print_dict(config, logger)
    logger.info('\n')
    logger.info('{} ser config {}'.format('*' * 10, '*' * 10))
    print_dict(ser_config, logger)

    # NOTE(review): the message says "train" but this is an inference entry
    # point — consider rewording.
    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, ser_config, device, logger
if __name__ == '__main__':
    # SER+RE inference entry point: run the two-stage engine over every
    # input image, dump JSON results to infer_results.txt and save
    # relation visualizations.
    config, ser_config, device, logger = preprocess()
    os.makedirs(config['Global']['save_res_path'], exist_ok=True)

    ser_re_engine = SerRePredictor(config, ser_config)

    infer_imgs = get_image_file_list(config['Global']['infer_img'])
    with open(
            os.path.join(config['Global']['save_res_path'],
                         "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, img_path in enumerate(infer_imgs):
            # NOTE(review): suffix "_ser.jpg" is the same as the SER-only
            # script's — outputs collide if both scripts share save_res_path.
            save_img_path = os.path.join(
                config['Global']['save_res_path'],
                os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
            logger.info("process: [{}/{}], save result to {}".format(
                idx, len(infer_imgs), save_img_path))

            result = ser_re_engine(img_path)
            result = result[0]
            # Fixed typo in the serialized key: "ser_resule" -> "ser_result".
            fout.write(img_path + "\t" + json.dumps(
                {
                    "ser_result": result,
                }, ensure_ascii=False) + "\n")
            img_res = draw_re_results(img_path, result)
            cv2.imwrite(save_img_path, img_res)
tools/program.py
View file @
a323fce6
...
...
@@ -69,24 +69,6 @@ class ArgsParser(ArgumentParser):
return
config
class
AttrDict
(
dict
):
"""Single level attribute dict, NOT recursive"""
def
__init__
(
self
,
**
kwargs
):
super
(
AttrDict
,
self
).
__init__
()
super
(
AttrDict
,
self
).
update
(
kwargs
)
def
__getattr__
(
self
,
key
):
if
key
in
self
:
return
self
[
key
]
raise
AttributeError
(
"object has no attribute '{}'"
.
format
(
key
))
global_config
=
AttrDict
()
default_config
=
{
'Global'
:
{
'debug'
:
False
,
}}
def
load_config
(
file_path
):
"""
Load config from yml/yaml file.
...
...
@@ -94,38 +76,38 @@ def load_config(file_path):
file_path (str): Path of the config file to be loaded.
Returns: global config
"""
merge_config
(
default_config
)
_
,
ext
=
os
.
path
.
splitext
(
file_path
)
assert
ext
in
[
'.yml'
,
'.yaml'
],
"only support yaml files for now"
merge_
config
(
yaml
.
load
(
open
(
file_path
,
'rb'
),
Loader
=
yaml
.
Loader
)
)
return
global_
config
config
=
yaml
.
load
(
open
(
file_path
,
'rb'
),
Loader
=
yaml
.
Loader
)
return
config
def
merge_config
(
config
):
def
merge_config
(
config
,
opts
):
"""
Merge config into global config.
Args:
config (dict): Config to be merged.
Returns: global config
"""
for
key
,
value
in
config
.
items
():
for
key
,
value
in
opts
.
items
():
if
"."
not
in
key
:
if
isinstance
(
value
,
dict
)
and
key
in
global_
config
:
global_
config
[
key
].
update
(
value
)
if
isinstance
(
value
,
dict
)
and
key
in
config
:
config
[
key
].
update
(
value
)
else
:
global_
config
[
key
]
=
value
config
[
key
]
=
value
else
:
sub_keys
=
key
.
split
(
'.'
)
assert
(
sub_keys
[
0
]
in
global_
config
sub_keys
[
0
]
in
config
),
"the sub_keys can only be one of global_config: {}, but get: {}, please check your running command"
.
format
(
global_
config
.
keys
(),
sub_keys
[
0
])
cur
=
global_
config
[
sub_keys
[
0
]]
config
.
keys
(),
sub_keys
[
0
])
cur
=
config
[
sub_keys
[
0
]]
for
idx
,
sub_key
in
enumerate
(
sub_keys
[
1
:]):
if
idx
==
len
(
sub_keys
)
-
2
:
cur
[
sub_key
]
=
value
else
:
cur
=
cur
[
sub_key
]
return
config
def
check_gpu
(
use_gpu
):
...
...
@@ -204,20 +186,24 @@ def train(config,
model_type
=
None
algorithm
=
config
[
'Architecture'
][
'algorithm'
]
if
'start_epoch'
in
best_model_dict
:
start_epoch
=
best_model_dict
[
'start_epoch'
]
else
:
start_epoch
=
1
start_epoch
=
best_model_dict
[
'start_epoch'
]
if
'start_epoch'
in
best_model_dict
else
1
train_reader_cost
=
0.0
train_run_cost
=
0.0
total_samples
=
0
reader_start
=
time
.
time
()
max_iter
=
len
(
train_dataloader
)
-
1
if
platform
.
system
(
)
==
"Windows"
else
len
(
train_dataloader
)
for
epoch
in
range
(
start_epoch
,
epoch_num
+
1
):
train_dataloader
=
build_dataloader
(
config
,
'Train'
,
device
,
logger
,
seed
=
epoch
)
train_reader_cost
=
0.0
train_run_cost
=
0.0
total_samples
=
0
reader_start
=
time
.
time
()
max_iter
=
len
(
train_dataloader
)
-
1
if
platform
.
system
(
)
==
"Windows"
else
len
(
train_dataloader
)
if
train_dataloader
.
dataset
.
need_reset
:
train_dataloader
=
build_dataloader
(
config
,
'Train'
,
device
,
logger
,
seed
=
epoch
)
max_iter
=
len
(
train_dataloader
)
-
1
if
platform
.
system
(
)
==
"Windows"
else
len
(
train_dataloader
)
for
idx
,
batch
in
enumerate
(
train_dataloader
):
profiler
.
add_profiler_step
(
profiler_options
)
train_reader_cost
+=
time
.
time
()
-
reader_start
...
...
@@ -239,10 +225,11 @@ def train(config,
else
:
if
model_type
==
'table'
or
extra_input
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
elif
model_type
==
"kie"
:
elif
model_type
in
[
"kie"
,
'vqa'
]
:
preds
=
model
(
batch
)
else
:
preds
=
model
(
images
)
loss
=
loss_class
(
preds
,
batch
)
avg_loss
=
loss
[
'loss'
]
...
...
@@ -256,6 +243,7 @@ def train(config,
optimizer
.
clear_grad
()
train_run_cost
+=
time
.
time
()
-
train_start
global_step
+=
1
total_samples
+=
len
(
images
)
if
not
isinstance
(
lr_scheduler
,
float
):
...
...
@@ -285,12 +273,13 @@ def train(config,
(
global_step
>
0
and
global_step
%
print_batch_step
==
0
)
or
(
idx
>=
len
(
train_dataloader
)
-
1
)):
logs
=
train_stats
.
log
()
strs
=
'epoch: [{}/{}],
i
te
r
: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'
.
format
(
strs
=
'epoch: [{}/{}],
global_s
te
p
: {}, {},
avg_
reader_cost: {:.5f} s,
avg_
batch_cost: {:.5f} s,
avg_
samples: {}, ips: {:.5f}'
.
format
(
epoch
,
epoch_num
,
global_step
,
logs
,
train_reader_cost
/
print_batch_step
,
(
train_reader_cost
+
train_run_cost
)
/
print_batch_step
,
total_samples
,
print_batch_step
,
total_samples
/
print_batch_step
,
total_samples
/
(
train_reader_cost
+
train_run_cost
))
logger
.
info
(
strs
)
train_reader_cost
=
0.0
train_run_cost
=
0.0
total_samples
=
0
...
...
@@ -330,6 +319,7 @@ def train(config,
optimizer
,
save_model_dir
,
logger
,
config
,
is_best
=
True
,
prefix
=
'best_accuracy'
,
best_model_dict
=
best_model_dict
,
...
...
@@ -344,8 +334,7 @@ def train(config,
vdl_writer
.
add_scalar
(
'EVAL/best_{}'
.
format
(
main_indicator
),
best_model_dict
[
main_indicator
],
global_step
)
global_step
+=
1
optimizer
.
clear_grad
()
reader_start
=
time
.
time
()
if
dist
.
get_rank
()
==
0
:
save_model
(
...
...
@@ -353,6 +342,7 @@ def train(config,
optimizer
,
save_model_dir
,
logger
,
config
,
is_best
=
False
,
prefix
=
'latest'
,
best_model_dict
=
best_model_dict
,
...
...
@@ -364,6 +354,7 @@ def train(config,
optimizer
,
save_model_dir
,
logger
,
config
,
is_best
=
False
,
prefix
=
'iter_epoch_{}'
.
format
(
epoch
),
best_model_dict
=
best_model_dict
,
...
...
@@ -401,19 +392,28 @@ def eval(model,
start
=
time
.
time
()
if
model_type
==
'table'
or
extra_input
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
elif
model_type
==
"kie"
:
elif
model_type
in
[
"kie"
,
'vqa'
]
:
preds
=
model
(
batch
)
else
:
preds
=
model
(
images
)
batch
=
[
item
.
numpy
()
for
item
in
batch
]
batch_numpy
=
[]
for
item
in
batch
:
if
isinstance
(
item
,
paddle
.
Tensor
):
batch_numpy
.
append
(
item
.
numpy
())
else
:
batch_numpy
.
append
(
item
)
# Obtain usable results from post-processing methods
total_time
+=
time
.
time
()
-
start
# Evaluate the results of the current batch
if
model_type
in
[
'table'
,
'kie'
]:
eval_class
(
preds
,
batch
)
eval_class
(
preds
,
batch_numpy
)
elif
model_type
in
[
'vqa'
]:
post_result
=
post_process_class
(
preds
,
batch_numpy
)
eval_class
(
post_result
,
batch_numpy
)
else
:
post_result
=
post_process_class
(
preds
,
batch
[
1
])
eval_class
(
post_result
,
batch
)
post_result
=
post_process_class
(
preds
,
batch
_numpy
[
1
])
eval_class
(
post_result
,
batch
_numpy
)
pbar
.
update
(
1
)
total_frame
+=
len
(
images
)
...
...
@@ -479,9 +479,9 @@ def preprocess(is_train=False):
FLAGS
=
ArgsParser
().
parse_args
()
profiler_options
=
FLAGS
.
profiler_options
config
=
load_config
(
FLAGS
.
config
)
merge_config
(
FLAGS
.
opt
)
config
=
merge_config
(
config
,
FLAGS
.
opt
)
profile_dic
=
{
"profiler_options"
:
FLAGS
.
profiler_options
}
merge_config
(
profile_dic
)
config
=
merge_config
(
config
,
profile_dic
)
if
is_train
:
# save_config
...
...
@@ -503,13 +503,8 @@ def preprocess(is_train=False):
assert
alg
in
[
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
,
'PSE'
,
'SEED'
,
'SDMGR'
'SEED'
,
'SDMGR'
,
'LayoutXLM'
,
'LayoutLM'
]
windows_not_support_list
=
[
'PSE'
]
if
platform
.
system
()
==
"Windows"
and
alg
in
windows_not_support_list
:
logger
.
warning
(
'{} is not support in Windows now'
.
format
(
windows_not_support_list
))
sys
.
exit
()
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
device
=
paddle
.
set_device
(
device
)
...
...
tools/train.py
View file @
a323fce6
...
...
@@ -97,7 +97,8 @@ def main(config, device, logger, vdl_writer):
# build metric
eval_class
=
build_metric
(
config
[
'Metric'
])
# load pretrain model
pre_best_model_dict
=
load_model
(
config
,
model
,
optimizer
)
pre_best_model_dict
=
load_model
(
config
,
model
,
optimizer
,
config
[
'Architecture'
][
"model_type"
])
logger
.
info
(
'train dataloader has {} iters'
.
format
(
len
(
train_dataloader
)))
if
valid_dataloader
is
not
None
:
logger
.
info
(
'valid dataloader has {} iters'
.
format
(
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment