Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
68deaab1
Unverified
Commit
68deaab1
authored
Dec 19, 2021
by
MissPenguin
Committed by
GitHub
Dec 19, 2021
Browse files
Merge pull request #4967 from WenmuZhou/fix_vqa
ser and re support distributed train
parents
57f01253
3ffaf7f2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
74 additions
and
78 deletions
+74
-78
ppstructure/vqa/helper/eval_with_label_end2end.py
ppstructure/vqa/helper/eval_with_label_end2end.py
+3
-6
ppstructure/vqa/train_re.py
ppstructure/vqa/train_re.py
+34
-31
ppstructure/vqa/train_ser.py
ppstructure/vqa/train_ser.py
+37
-41
No files found.
ppstructure/vqa/helper/eval_with_label_end2end.py
View file @
68deaab1
...
@@ -15,13 +15,12 @@
...
@@ -15,13 +15,12 @@
import
os
import
os
import
re
import
re
import
sys
import
sys
# import Polygon
import
shapely
import
shapely
from
shapely.geometry
import
Polygon
from
shapely.geometry
import
Polygon
import
numpy
as
np
import
numpy
as
np
from
collections
import
defaultdict
from
collections
import
defaultdict
import
operator
import
operator
import
editdistance
import
Levenshtein
import
argparse
import
argparse
import
json
import
json
import
copy
import
copy
...
@@ -95,7 +94,7 @@ def ed(args, str1, str2):
...
@@ -95,7 +94,7 @@ def ed(args, str1, str2):
if
args
.
ignore_case
:
if
args
.
ignore_case
:
str1
=
str1
.
lower
()
str1
=
str1
.
lower
()
str2
=
str2
.
lower
()
str2
=
str2
.
lower
()
return
edit
distance
.
eval
(
str1
,
str2
)
return
Levenshtein
.
distance
(
str1
,
str2
)
def
convert_bbox_to_polygon
(
bbox
):
def
convert_bbox_to_polygon
(
bbox
):
...
@@ -115,8 +114,6 @@ def eval_e2e(args):
...
@@ -115,8 +114,6 @@ def eval_e2e(args):
# pred
# pred
dt_results
=
parse_ser_results_fp
(
args
.
pred_json_path
,
"pred"
,
dt_results
=
parse_ser_results_fp
(
args
.
pred_json_path
,
"pred"
,
args
.
ignore_background
)
args
.
ignore_background
)
assert
set
(
gt_results
.
keys
())
==
set
(
dt_results
.
keys
())
iou_thresh
=
args
.
iou_thres
iou_thresh
=
args
.
iou_thres
num_gt_chars
=
0
num_gt_chars
=
0
gt_count
=
0
gt_count
=
0
...
@@ -124,7 +121,7 @@ def eval_e2e(args):
...
@@ -124,7 +121,7 @@ def eval_e2e(args):
hit
=
0
hit
=
0
ed_sum
=
0
ed_sum
=
0
for
img_name
in
g
t_results
:
for
img_name
in
d
t_results
:
gt_info
=
gt_results
[
img_name
]
gt_info
=
gt_results
[
img_name
]
gt_count
+=
len
(
gt_info
)
gt_count
+=
len
(
gt_info
)
...
...
ppstructure/vqa/train_re.py
View file @
68deaab1
...
@@ -36,6 +36,9 @@ from ppocr.utils.logging import get_logger
...
@@ -36,6 +36,9 @@ from ppocr.utils.logging import get_logger
def
train
(
args
):
def
train
(
args
):
logger
=
get_logger
(
log_file
=
os
.
path
.
join
(
args
.
output_dir
,
"train.log"
))
logger
=
get_logger
(
log_file
=
os
.
path
.
join
(
args
.
output_dir
,
"train.log"
))
rank
=
paddle
.
distributed
.
get_rank
()
distributed
=
paddle
.
distributed
.
get_world_size
()
>
1
print_arguments
(
args
,
logger
)
print_arguments
(
args
,
logger
)
# Added here for reproducibility (even between python 2 and 3)
# Added here for reproducibility (even between python 2 and 3)
...
@@ -45,7 +48,7 @@ def train(args):
...
@@ -45,7 +48,7 @@ def train(args):
pad_token_label_id
=
paddle
.
nn
.
CrossEntropyLoss
().
ignore_index
pad_token_label_id
=
paddle
.
nn
.
CrossEntropyLoss
().
ignore_index
# dist mode
# dist mode
if
paddle
.
distributed
.
get_world_size
()
>
1
:
if
distributed
:
paddle
.
distributed
.
init_parallel_env
()
paddle
.
distributed
.
init_parallel_env
()
tokenizer
=
LayoutXLMTokenizer
.
from_pretrained
(
args
.
model_name_or_path
)
tokenizer
=
LayoutXLMTokenizer
.
from_pretrained
(
args
.
model_name_or_path
)
...
@@ -59,8 +62,8 @@ def train(args):
...
@@ -59,8 +62,8 @@ def train(args):
args
.
model_name_or_path
)
args
.
model_name_or_path
)
# dist mode
# dist mode
if
paddle
.
distributed
.
get_world_size
()
>
1
:
if
distributed
:
model
=
paddle
.
distributed
.
DataParallel
(
model
)
model
=
paddle
.
DataParallel
(
model
)
train_dataset
=
XFUNDataset
(
train_dataset
=
XFUNDataset
(
tokenizer
,
tokenizer
,
...
@@ -90,8 +93,7 @@ def train(args):
...
@@ -90,8 +93,7 @@ def train(args):
train_sampler
=
paddle
.
io
.
DistributedBatchSampler
(
train_sampler
=
paddle
.
io
.
DistributedBatchSampler
(
train_dataset
,
batch_size
=
args
.
per_gpu_train_batch_size
,
shuffle
=
True
)
train_dataset
,
batch_size
=
args
.
per_gpu_train_batch_size
,
shuffle
=
True
)
args
.
train_batch_size
=
args
.
per_gpu_train_batch_size
*
\
max
(
1
,
paddle
.
distributed
.
get_world_size
())
train_dataloader
=
paddle
.
io
.
DataLoader
(
train_dataloader
=
paddle
.
io
.
DataLoader
(
train_dataset
,
train_dataset
,
batch_sampler
=
train_sampler
,
batch_sampler
=
train_sampler
,
...
@@ -136,7 +138,8 @@ def train(args):
...
@@ -136,7 +138,8 @@ def train(args):
args
.
per_gpu_train_batch_size
))
args
.
per_gpu_train_batch_size
))
logger
.
info
(
logger
.
info
(
" Total train batch size (w. parallel, distributed & accumulation) = {}"
.
" Total train batch size (w. parallel, distributed & accumulation) = {}"
.
format
(
args
.
train_batch_size
*
paddle
.
distributed
.
get_world_size
()))
format
(
args
.
per_gpu_train_batch_size
*
paddle
.
distributed
.
get_world_size
()))
logger
.
info
(
" Total optimization steps = {}"
.
format
(
t_total
))
logger
.
info
(
" Total optimization steps = {}"
.
format
(
t_total
))
global_step
=
0
global_step
=
0
...
@@ -170,7 +173,7 @@ def train(args):
...
@@ -170,7 +173,7 @@ def train(args):
global_step
+=
1
global_step
+=
1
total_samples
+=
batch
[
'image'
].
shape
[
0
]
total_samples
+=
batch
[
'image'
].
shape
[
0
]
if
step
%
print_step
==
0
:
if
rank
==
0
and
step
%
print_step
==
0
:
logger
.
info
(
logger
.
info
(
"epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec"
.
"epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec"
.
format
(
epoch
,
args
.
num_train_epochs
,
step
,
format
(
epoch
,
args
.
num_train_epochs
,
step
,
...
@@ -185,38 +188,38 @@ def train(args):
...
@@ -185,38 +188,38 @@ def train(args):
train_run_cost
=
0.0
train_run_cost
=
0.0
total_samples
=
0
total_samples
=
0
if
(
paddle
.
distributed
.
get_rank
()
==
0
and
args
.
eval_steps
>
0
and
if
rank
==
0
and
args
.
eval_steps
>
0
and
global_step
%
args
.
eval_steps
==
0
and
args
.
evaluate_during_training
:
global_step
%
args
.
eval_steps
==
0
):
# Log metrics
# Log metrics
if
(
paddle
.
distributed
.
get_rank
()
==
0
and
args
.
# Only evaluate when single GPU otherwise metrics may not average well
evaluate_during_training
):
# Only evaluate when single GPU otherwise metrics may not average well
results
=
evaluate
(
model
,
eval_dataloader
,
logger
)
results
=
evaluate
(
model
,
eval_dataloader
,
logger
)
if
results
[
'f1'
]
>=
best_metirc
[
'f1'
]:
if
results
[
'f1'
]
>=
best_metirc
[
'f1'
]:
best_metirc
=
results
best_metirc
=
results
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"best_model"
)
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"best_model"
)
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
if
distributed
:
model
.
_layers
.
save_pretrained
(
output_dir
)
else
:
model
.
save_pretrained
(
output_dir
)
model
.
save_pretrained
(
output_dir
)
tokenizer
.
save_pretrained
(
output_dir
)
paddle
.
save
(
args
,
os
.
path
.
join
(
output_dir
,
"training_args.bin"
))
logger
.
info
(
"Saving model checkpoint to {}"
.
format
(
output_dir
))
logger
.
info
(
"eval results: {}"
.
format
(
results
))
logger
.
info
(
"best_metirc: {}"
.
format
(
best_metirc
))
if
paddle
.
distributed
.
get_rank
()
==
0
:
# Save model checkpoint
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"latest_model"
)
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
if
paddle
.
distributed
.
get_rank
()
==
0
:
model
.
save_pretrained
(
output_dir
)
tokenizer
.
save_pretrained
(
output_dir
)
tokenizer
.
save_pretrained
(
output_dir
)
paddle
.
save
(
args
,
paddle
.
save
(
args
,
os
.
path
.
join
(
output_dir
,
"training_args.bin"
))
os
.
path
.
join
(
output_dir
,
"training_args.bin"
))
logger
.
info
(
"Saving model checkpoint to {}"
.
format
(
logger
.
info
(
"Saving model checkpoint to {}"
.
format
(
output_dir
))
output_dir
))
logger
.
info
(
"eval results: {}"
.
format
(
results
))
logger
.
info
(
"best_metirc: {}"
.
format
(
best_metirc
))
reader_start
=
time
.
time
()
reader_start
=
time
.
time
()
if
rank
==
0
:
# Save model checkpoint
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"latest_model"
)
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
if
distributed
:
model
.
_layers
.
save_pretrained
(
output_dir
)
else
:
model
.
save_pretrained
(
output_dir
)
tokenizer
.
save_pretrained
(
output_dir
)
paddle
.
save
(
args
,
os
.
path
.
join
(
output_dir
,
"training_args.bin"
))
logger
.
info
(
"Saving model checkpoint to {}"
.
format
(
output_dir
))
logger
.
info
(
"best_metirc: {}"
.
format
(
best_metirc
))
logger
.
info
(
"best_metirc: {}"
.
format
(
best_metirc
))
...
...
ppstructure/vqa/train_ser.py
View file @
68deaab1
...
@@ -37,6 +37,9 @@ from ppocr.utils.logging import get_logger
...
@@ -37,6 +37,9 @@ from ppocr.utils.logging import get_logger
def
train
(
args
):
def
train
(
args
):
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
os
.
makedirs
(
args
.
output_dir
,
exist_ok
=
True
)
rank
=
paddle
.
distributed
.
get_rank
()
distributed
=
paddle
.
distributed
.
get_world_size
()
>
1
logger
=
get_logger
(
log_file
=
os
.
path
.
join
(
args
.
output_dir
,
"train.log"
))
logger
=
get_logger
(
log_file
=
os
.
path
.
join
(
args
.
output_dir
,
"train.log"
))
print_arguments
(
args
,
logger
)
print_arguments
(
args
,
logger
)
...
@@ -44,7 +47,7 @@ def train(args):
...
@@ -44,7 +47,7 @@ def train(args):
pad_token_label_id
=
paddle
.
nn
.
CrossEntropyLoss
().
ignore_index
pad_token_label_id
=
paddle
.
nn
.
CrossEntropyLoss
().
ignore_index
# dist mode
# dist mode
if
paddle
.
distributed
.
get_world_size
()
>
1
:
if
distributed
:
paddle
.
distributed
.
init_parallel_env
()
paddle
.
distributed
.
init_parallel_env
()
tokenizer
=
LayoutXLMTokenizer
.
from_pretrained
(
args
.
model_name_or_path
)
tokenizer
=
LayoutXLMTokenizer
.
from_pretrained
(
args
.
model_name_or_path
)
...
@@ -59,7 +62,7 @@ def train(args):
...
@@ -59,7 +62,7 @@ def train(args):
args
.
model_name_or_path
)
args
.
model_name_or_path
)
# dist mode
# dist mode
if
paddle
.
distributed
.
get_world_size
()
>
1
:
if
distributed
:
model
=
paddle
.
DataParallel
(
model
)
model
=
paddle
.
DataParallel
(
model
)
train_dataset
=
XFUNDataset
(
train_dataset
=
XFUNDataset
(
...
@@ -88,9 +91,6 @@ def train(args):
...
@@ -88,9 +91,6 @@ def train(args):
train_sampler
=
paddle
.
io
.
DistributedBatchSampler
(
train_sampler
=
paddle
.
io
.
DistributedBatchSampler
(
train_dataset
,
batch_size
=
args
.
per_gpu_train_batch_size
,
shuffle
=
True
)
train_dataset
,
batch_size
=
args
.
per_gpu_train_batch_size
,
shuffle
=
True
)
args
.
train_batch_size
=
args
.
per_gpu_train_batch_size
*
max
(
1
,
paddle
.
distributed
.
get_world_size
())
train_dataloader
=
paddle
.
io
.
DataLoader
(
train_dataloader
=
paddle
.
io
.
DataLoader
(
train_dataset
,
train_dataset
,
batch_sampler
=
train_sampler
,
batch_sampler
=
train_sampler
,
...
@@ -134,7 +134,7 @@ def train(args):
...
@@ -134,7 +134,7 @@ def train(args):
args
.
per_gpu_train_batch_size
)
args
.
per_gpu_train_batch_size
)
logger
.
info
(
logger
.
info
(
" Total train batch size (w. parallel, distributed) = %d"
,
" Total train batch size (w. parallel, distributed) = %d"
,
args
.
train_batch_size
*
paddle
.
distributed
.
get_world_size
(),
)
args
.
per_gpu_
train_batch_size
*
paddle
.
distributed
.
get_world_size
(),
)
logger
.
info
(
" Total optimization steps = %d"
,
t_total
)
logger
.
info
(
" Total optimization steps = %d"
,
t_total
)
global_step
=
0
global_step
=
0
...
@@ -168,7 +168,7 @@ def train(args):
...
@@ -168,7 +168,7 @@ def train(args):
global_step
+=
1
global_step
+=
1
total_samples
+=
batch
[
'image'
].
shape
[
0
]
total_samples
+=
batch
[
'image'
].
shape
[
0
]
if
step
%
print_step
==
0
:
if
rank
==
0
and
step
%
print_step
==
0
:
logger
.
info
(
logger
.
info
(
"epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec"
.
"epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec"
.
format
(
epoch_id
,
args
.
num_train_epochs
,
step
,
format
(
epoch_id
,
args
.
num_train_epochs
,
step
,
...
@@ -183,47 +183,43 @@ def train(args):
...
@@ -183,47 +183,43 @@ def train(args):
train_run_cost
=
0.0
train_run_cost
=
0.0
total_samples
=
0
total_samples
=
0
if
(
paddle
.
distributed
.
get_rank
()
==
0
and
args
.
eval_steps
>
0
and
if
rank
==
0
and
args
.
eval_steps
>
0
and
global_step
%
args
.
eval_steps
==
0
and
args
.
evaluate_during_training
:
global_step
%
args
.
eval_steps
==
0
):
# Log metrics
# Log metrics
# Only evaluate when single GPU otherwise metrics may not average well
# Only evaluate when single GPU otherwise metrics may not average well
if
paddle
.
distributed
.
get_rank
(
results
,
_
=
evaluate
(
args
,
model
,
tokenizer
,
eval_dataloader
,
)
==
0
and
args
.
evaluate_during_training
:
label2id_map
,
id2label_map
,
results
,
_
=
evaluate
(
pad_token_label_id
,
logger
)
args
,
model
,
tokenizer
,
eval_dataloader
,
label2id_map
,
id2label_map
,
pad_token_label_id
,
logger
)
if
best_metrics
is
None
or
results
[
"f1"
]
>=
best_metrics
[
"f1"
]:
best_metrics
=
copy
.
deepcopy
(
results
)
if
best_metrics
is
None
or
results
[
"f1"
]
>=
best_metrics
[
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"best_model"
)
"f1"
]:
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
best_metrics
=
copy
.
deepcopy
(
results
)
if
distributed
:
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"best_model"
)
model
.
_layers
.
save_pretrained
(
output_dir
)
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
else
:
if
paddle
.
distributed
.
get_rank
()
==
0
:
model
.
save_pretrained
(
output_dir
)
model
.
save_pretrained
(
output_dir
)
tokenizer
.
save_pretrained
(
output_dir
)
paddle
.
save
(
args
,
os
.
path
.
join
(
output_dir
,
"training_args.bin"
))
logger
.
info
(
"Saving model checkpoint to %s"
,
output_dir
)
logger
.
info
(
"[epoch {}/{}][iter: {}/{}] results: {}"
.
format
(
epoch_id
,
args
.
num_train_epochs
,
step
,
len
(
train_dataloader
),
results
))
if
best_metrics
is
not
None
:
logger
.
info
(
"best metrics: {}"
.
format
(
best_metrics
))
if
paddle
.
distributed
.
get_rank
()
==
0
:
# Save model checkpoint
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"latest_model"
)
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
if
paddle
.
distributed
.
get_rank
()
==
0
:
model
.
save_pretrained
(
output_dir
)
tokenizer
.
save_pretrained
(
output_dir
)
tokenizer
.
save_pretrained
(
output_dir
)
paddle
.
save
(
args
,
paddle
.
save
(
args
,
os
.
path
.
join
(
output_dir
,
"training_args.bin"
))
os
.
path
.
join
(
output_dir
,
"training_args.bin"
))
logger
.
info
(
"Saving model checkpoint to %s"
,
output_dir
)
logger
.
info
(
"Saving model checkpoint to %s"
,
output_dir
)
logger
.
info
(
"[epoch {}/{}][iter: {}/{}] results: {}"
.
format
(
epoch_id
,
args
.
num_train_epochs
,
step
,
len
(
train_dataloader
),
results
))
if
best_metrics
is
not
None
:
logger
.
info
(
"best metrics: {}"
.
format
(
best_metrics
))
reader_start
=
time
.
time
()
reader_start
=
time
.
time
()
if
rank
==
0
:
# Save model checkpoint
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"latest_model"
)
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
if
distributed
:
model
.
_layers
.
save_pretrained
(
output_dir
)
else
:
model
.
save_pretrained
(
output_dir
)
tokenizer
.
save_pretrained
(
output_dir
)
paddle
.
save
(
args
,
os
.
path
.
join
(
output_dir
,
"training_args.bin"
))
logger
.
info
(
"Saving model checkpoint to %s"
,
output_dir
)
return
global_step
,
tr_loss
/
global_step
return
global_step
,
tr_loss
/
global_step
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment