chenpangpang / transformers · Commits

Commit 54a31f50
Authored Oct 05, 2019 by jinoobaek-qz
Committed by Lysandre Debut, Oct 09, 2019
Add save_total_limit
Parent: 1c507995

Showing 1 changed file with 22 additions and 0 deletions.
examples/run_lm_finetuning.py (+22, -0)
@@ -27,6 +27,8 @@ import logging
 import os
 import pickle
 import random
+import re
+import shutil
 
 import numpy as np
 import torch
@@ -222,6 +224,24 @@ def train(args, train_dataset, model, tokenizer):
                     logging_loss = tr_loss
 
                 if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
+                    if args.save_total_limit and args.save_total_limit > 0:
+                        # Check if we should delete older checkpoint(s)
+                        glob_checkpoints = glob.glob(os.path.join(args.output_dir, 'checkpoint-*'))
+                        if len(glob_checkpoints) + 1 > args.save_total_limit:
+                            checkpoints_sorted = []
+                            for path in glob_checkpoints:
+                                regex_match = re.match('.*checkpoint-([0-9]+)', path)
+                                if regex_match and regex_match.groups():
+                                    checkpoints_sorted.append((int(regex_match.groups()[0]), path))
+                            checkpoints_sorted = sorted(checkpoints_sorted)
+                            checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
+                            number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) + 1 - args.save_total_limit)
+                            checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
+                            for checkpoint in checkpoints_to_be_deleted:
+                                logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
+                                shutil.rmtree(checkpoint)
+
                     # Save model checkpoint
                     output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                     if not os.path.exists(output_dir):
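The added block rotates checkpoints before each save: it globs the existing checkpoint-* directories under args.output_dir, sorts them by the step number parsed out of the directory name, and deletes the oldest ones so that, once the upcoming checkpoint is written, at most args.save_total_limit remain. A minimal standalone sketch of the same idea follows; the helper name rotate_checkpoints and its parameters are illustrative and not part of this commit.

import glob
import os
import re
import shutil

def rotate_checkpoints(output_dir, save_total_limit, prefix='checkpoint'):
    # Sketch only: remove the oldest '<prefix>-<step>' directories so that,
    # counting the checkpoint about to be written, at most save_total_limit remain.
    if not save_total_limit or save_total_limit <= 0:
        return
    candidates = []
    for path in glob.glob(os.path.join(output_dir, '{}-*'.format(prefix))):
        match = re.match(r'.*{}-([0-9]+)'.format(prefix), path)
        if match:
            candidates.append((int(match.group(1)), path))
    candidates.sort()  # oldest (smallest step number) first
    number_to_delete = max(0, len(candidates) + 1 - save_total_limit)
    for _, path in candidates[:number_to_delete]:
        shutil.rmtree(path)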
@@ -359,6 +379,8 @@ def main():
                         help="Log every X updates steps.")
     parser.add_argument('--save_steps', type=int, default=50,
                         help="Save checkpoint every X updates steps.")
+    parser.add_argument('--save_total_limit', type=int, default=None,
+                        help='Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default')
     parser.add_argument("--eval_all_checkpoints", action='store_true',
                         help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
     parser.add_argument("--no_cuda", action='store_true',
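The new option defaults to None, so checkpoint deletion stays off unless the flag is passed. A small self-contained sketch of how it parses is below; the option strings match the diff, while the example values (500 and 2) are invented for illustration.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--save_steps', type=int, default=50,
                    help="Save checkpoint every X updates steps.")
parser.add_argument('--save_total_limit', type=int, default=None,
                    help='Limit the total amount of checkpoints kept in output_dir')

# Example invocation: save every 500 steps, keep only the two newest checkpoints.
args = parser.parse_args(['--save_steps', '500', '--save_total_limit', '2'])
assert args.save_total_limit == 2

# Default behaviour: save_total_limit stays None, so no checkpoints are ever deleted.
args_default = parser.parse_args([])
assert args_default.save_total_limit is None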