chenpangpang / transformers · Commits

Commit 528d3f32
Authored Oct 07, 2019 by jinoobaek-qz; committed by Lysandre Debut on Oct 09, 2019

Improve readability and make fewer assumptions about checkpoint format

Parent: 56301bd9
Showing 1 changed file with 26 additions and 18 deletions:

examples/run_lm_finetuning.py (+26, -18)
examples/run_lm_finetuning.py

@@ -106,15 +106,22 @@ def set_seed(args):
         torch.cuda.manual_seed_all(args.seed)
 
 
-def rotate_checkpoints(args):
-    if args.save_total_limit and args.save_total_limit > 0:
-        # Check if we should delete older checkpoint(s)
-        glob_checkpoints = glob.glob(os.path.join(args.output_dir, 'checkpoint-*'))
-        if len(glob_checkpoints) > args.save_total_limit:
-            checkpoints_sorted = []
-            for path in glob_checkpoints:
-                regex_match = re.match('.*checkpoint-([0-9]+)', path)
-                if regex_match and regex_match.groups():
-                    checkpoints_sorted.append((int(regex_match.groups()[0]), path))
-            checkpoints_sorted = sorted(checkpoints_sorted)
+def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
+    if not args.save_total_limit:
+        return
+    if args.save_total_limit <= 0:
+        return
+
+    # Check if we should delete older checkpoint(s)
+    glob_checkpoints = glob.glob(os.path.join(args.output_dir, '{}-*'.format(checkpoint_prefix)))
+    if len(glob_checkpoints) > args.save_total_limit:
+        checkpoints_sorted = []
+        for path in glob_checkpoints:
+            regex_match = re.match('.*{}-([0-9]+)'.format(checkpoint_prefix), path)
+            if regex_match and regex_match.groups():
+                if use_mtime:
+                    checkpoints_sorted.append((os.path.getmtime(path), path))
+                else:
+                    checkpoints_sorted.append((int(regex_match.groups()[0]), path))
+        checkpoints_sorted = sorted(checkpoints_sorted)
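Note: the rotation logic above keys each checkpoint on the integer step parsed from its directory name, with the prefix now parameterized instead of hard-coded. A minimal standalone sketch of that glob-and-sort step; the directory name 'out' and the step numbers are invented for illustration and are not part of the commit:

import re

# Hypothetical paths, as glob.glob() might return them from args.output_dir
glob_checkpoints = ['out/checkpoint-500', 'out/checkpoint-1500', 'out/checkpoint-1000']
checkpoint_prefix = 'checkpoint'

checkpoints_sorted = []
for path in glob_checkpoints:
    # Same pattern as the new code: the prefix is a parameter, not a literal
    regex_match = re.match('.*{}-([0-9]+)'.format(checkpoint_prefix), path)
    if regex_match and regex_match.groups():
        # Sorting on the parsed int (not the string) keeps 1000 before 1500
        checkpoints_sorted.append((int(regex_match.groups()[0]), path))

print(sorted(checkpoints_sorted))
# [(500, 'out/checkpoint-500'), (1000, 'out/checkpoint-1000'), (1500, 'out/checkpoint-1500')]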
@@ -244,8 +251,9 @@ def train(args, train_dataset, model, tokenizer):
                     logging_loss = tr_loss
 
                 if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
+                    checkpoint_prefix = 'checkpoint'
                     # Save model checkpoint
-                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
+                    output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_step))
                     if not os.path.exists(output_dir):
                         os.makedirs(output_dir)
                     model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
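The checkpoint_prefix introduced here is what later flows into _rotate_checkpoints. A quick sanity check that the parameterized format reproduces the old hard-coded path; the values are invented for illustration:

import os

checkpoint_prefix = 'checkpoint'
global_step = 500
output_dir_root = 'out'

old_style = os.path.join(output_dir_root, 'checkpoint-{}'.format(global_step))
new_style = os.path.join(output_dir_root, '{}-{}'.format(checkpoint_prefix, global_step))
assert old_style == new_style  # identical paths, e.g. 'out/checkpoint-500'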
@@ -253,7 +261,7 @@ def train(args, train_dataset, model, tokenizer):
                     torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                     logger.info("Saving model checkpoint to %s", output_dir)
 
-                    rotate_checkpoints(args)
+                    _rotate_checkpoints(args, checkpoint_prefix)
 
             if args.max_steps > 0 and global_step > args.max_steps:
                 epoch_iterator.close()
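The new use_mtime flag is left at its False default at this call site. When enabled, ordering comes from file modification times rather than the step number parsed from the name. A self-contained sketch of that ordering in isolation, using throwaway directories; everything here is illustrative, not part of the commit:

import os
import tempfile
import time

root = tempfile.mkdtemp()
for name in ('checkpoint-1', 'checkpoint-2'):
    os.makedirs(os.path.join(root, name))
    time.sleep(0.05)  # give the two directories distinct mtimes (filesystem-dependent)

paths = [os.path.join(root, n) for n in ('checkpoint-2', 'checkpoint-1')]
# Mirrors the use_mtime branch: (mtime, path) tuples sort oldest-first
checkpoints_sorted = sorted((os.path.getmtime(p), p) for p in paths)
print([p for _, p in checkpoints_sorted])  # 'checkpoint-1' first, since it was created earlier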