chenpangpang / transformers / Commits

Unverified commit a515caa3, authored May 18, 2021 by Sylvain Gugger, committed by GitHub on May 18, 2021

Fix checkpoint deletion (#11748)
parent b88e0e01
Showing 2 changed files with 49 additions and 10 deletions:

    src/transformers/trainer.py  (+18, -10)
    tests/test_trainer.py        (+31, -0)
src/transformers/trainer.py
@@ -1523,10 +1523,6 @@ class Trainer:
         if self.is_world_process_zero():
             self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))
 
-        # Maybe delete some older checkpoints.
-        if self.is_world_process_zero():
-            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
-
         # Save RNG state in non-distributed training
         rng_states = {
             "python": random.getstate(),
@@ -1552,6 +1548,10 @@ class Trainer:
         else:
             torch.save(rng_states, os.path.join(output_dir, f"rng_state_{local_rank}.pth"))
 
+        # Maybe delete some older checkpoints.
+        if self.is_world_process_zero():
+            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+
     def _load_optimizer_and_scheduler(self, checkpoint):
         """If optimizer and scheduler states exist, load them."""
         if checkpoint is None:
@@ -1924,7 +1924,7 @@ class Trainer:
                 ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
             else:
                 regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
-                if regex_match and regex_match.groups():
+                if regex_match is not None and regex_match.groups() is not None:
                     ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
 
         checkpoints_sorted = sorted(ordering_and_checkpoint_path)
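For context on the change above: `re.Match.groups()` returns a tuple, and an empty tuple is falsy even though it is not `None`, so a bare truthiness test and an explicit `is not None` comparison are not interchangeable. A small standalone illustration (hypothetical paths, not code from this commit):

    import re

    # A match against the checkpoint pattern: groups() is a one-element tuple.
    m = re.match(r".*checkpoint-([0-9]+)", "output/checkpoint-0")
    print(m.groups())               # ('0',)
    print(m.groups() is not None)   # True -- the explicit form used in the new code

    # Without a capturing group, groups() is an empty tuple: falsy, yet not None.
    m2 = re.match(r".*checkpoint-[0-9]+", "output/checkpoint-0")
    print(m2.groups())              # ()
    print(bool(m2.groups()))        # False
    print(m2.groups() is not None)  # True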
@@ -1932,10 +1932,8 @@ class Trainer:
         # Make sure we don't delete the best model.
         if self.state.best_model_checkpoint is not None:
             best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
-            checkpoints_sorted[best_model_index], checkpoints_sorted[-1] = (
-                checkpoints_sorted[-1],
-                checkpoints_sorted[best_model_index],
-            )
+            for i in range(best_model_index, len(checkpoints_sorted) - 2):
+                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
         return checkpoints_sorted
 
     def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
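The hunk above changes how `_sorted_checkpoints` protects the best checkpoint: rather than swapping it with the last entry, it is bubbled into the second-to-last slot, so the most recent checkpoint stays last and everything else keeps its relative order. A minimal standalone sketch of that loop (the `protect_best` helper is hypothetical, not part of the diff):

    # Bubble the best checkpoint to the second-to-last position; rotation later
    # deletes entries from the front of the sorted list.
    def protect_best(checkpoints_sorted, best):
        best_model_index = checkpoints_sorted.index(best)
        for i in range(best_model_index, len(checkpoints_sorted) - 2):
            checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
        return checkpoints_sorted

    ckpts = ["checkpoint-5", "checkpoint-10", "checkpoint-15", "checkpoint-20", "checkpoint-25"]
    print(protect_best(ckpts, "checkpoint-5"))
    # ['checkpoint-10', 'checkpoint-15', 'checkpoint-20', 'checkpoint-5', 'checkpoint-25']

Because deletion takes checkpoints from the front of this list, keeping at least two entries now preserves both the best and the most recent checkpoint.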
@@ -1947,7 +1945,17 @@ class Trainer:
         if len(checkpoints_sorted) <= self.args.save_total_limit:
             return
 
-        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
+        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
+        # we don't do to allow resuming.
+        save_total_limit = self.args.save_total_limit
+        if (
+            self.state.best_model_checkpoint is not None
+            and self.args.save_total_limit == 1
+            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
+        ):
+            save_total_limit = 2
+
+        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
         checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
         for checkpoint in checkpoints_to_be_deleted:
             logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
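With the guard above, `save_total_limit=1` is effectively raised to 2 whenever the best checkpoint is not also the most recent one, so the latest checkpoint survives for resuming. A toy rerun of that logic (plain variables standing in for `self.args` and `self.state`):

    save_total_limit = 1
    best_model_checkpoint = "checkpoint-5"
    # After _sorted_checkpoints, the best checkpoint already sits second-to-last:
    checkpoints_sorted = ["checkpoint-10", "checkpoint-5", "checkpoint-25"]

    # Same condition as in the diff: bump the effective limit so the most
    # recent checkpoint is not deleted out from under a future resume.
    if (
        best_model_checkpoint is not None
        and save_total_limit == 1
        and checkpoints_sorted[-1] != best_model_checkpoint
    ):
        save_total_limit = 2

    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
    print(checkpoints_sorted[:number_of_checkpoints_to_delete])  # ['checkpoint-10']

Only `checkpoint-10` is deleted; `checkpoint-5` (best) and `checkpoint-25` (latest) both survive, matching the `[5, 25]` expectation in the new test below.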
tests/test_trainer.py
@@ -21,6 +21,7 @@ import random
 import re
 import tempfile
 import unittest
+from pathlib import Path
 
 import numpy as np
@@ -45,6 +46,7 @@ from transformers.testing_utils import (
     require_torch_multi_gpu,
     slow,
 )
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.utils.hp_naming import TrialShortNamer
@@ -1048,6 +1050,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             trainer.train()
             self.assertTrue(isinstance(trainer.state.total_flos, float))
 
+    def check_checkpoint_deletion(self, trainer, output_dir, expected):
+        # Make fake checkpoints
+        for n in [5, 10, 15, 20, 25]:
+            os.makedirs(os.path.join(output_dir, f"{PREFIX_CHECKPOINT_DIR}-{n}"), exist_ok=True)
+        trainer._rotate_checkpoints(output_dir=output_dir)
+        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{PREFIX_CHECKPOINT_DIR}-*")]
+        values = [int(re.match(f".*{PREFIX_CHECKPOINT_DIR}-([0-9]+)", d).groups()[0]) for d in glob_checkpoints]
+        self.assertSetEqual(set(values), set(expected))
+
+    def test_checkpoint_rotation(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Without best model at end
+            trainer = get_regression_trainer(output_dir=tmp_dir, save_total_limit=2)
+            self.check_checkpoint_deletion(trainer, tmp_dir, [20, 25])
+
+            # With best model at end
+            trainer = get_regression_trainer(output_dir=tmp_dir, load_best_model_at_end=True, save_total_limit=2)
+            trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5")
+            self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25])
+
+            # Edge case: we don't always honor save_total_limit=1 if load_best_model_at_end=True, to be able to
+            # resume from checkpoint.
+            trainer = get_regression_trainer(output_dir=tmp_dir, load_best_model_at_end=True, save_total_limit=1)
+            trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-25")
+            self.check_checkpoint_deletion(trainer, tmp_dir, [25])
+
+            trainer.state.best_model_checkpoint = os.path.join(tmp_dir, "checkpoint-5")
+            self.check_checkpoint_deletion(trainer, tmp_dir, [5, 25])
+
     def check_mem_metrics(self, trainer, check_func):
         metrics = trainer.train().metrics
         check_func("init_mem_cpu_alloc_delta", metrics)
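To run just the new rotation test locally, a pytest invocation along these lines should work (exact flags depend on the local test setup):

    python -m pytest tests/test_trainer.py -k test_checkpoint_rotation -v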